yiyixuxu
2025-11-20 22:13:15 +01:00
parent d5da453de5
commit 8b7ea8110b
7 changed files with 2297 additions and 0 deletions

scripts/convert_hunyuan_video1_5_to_diffusers.py

@@ -0,0 +1,382 @@
"""
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
--original_state_dict_folder /raid/yiyi/new-model-vid \
--output_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \
--transformer_type 480p_i2v \
--dtype fp32
"""
import argparse
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
import pathlib
from diffusers import HunyuanVideo15Transformer3DModel
TRANSFORMER_CONFIGS = {
"480p_i2v": {
"in_channels": 65,
"out_channels": 32,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 54,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 1,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"text_embed_dim": 3584,
"text_embed_2_dim": 1472,
"image_embed_dim": 1152,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"use_meanflow": False,
},
}
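# Note: the transformer's inner dimension is num_attention_heads * attention_head_dim
# (16 * 128 = 2048 for the "480p_i2v" entry above), and the hard-coded block counts used
# during conversion below are expected to match `num_layers` / `num_refiner_layers` here.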
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
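# The original final-layer modulation stores its parameters as [shift, scale], while the
# AdaLayerNormContinuous used for `norm_out` in diffusers chunks them as [scale, shift],
# so the two halves are swapped before loading.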
def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
"""
Convert HunyuanVideo 1.5 original checkpoint to Diffusers format.
"""
converted_state_dict = {}
# 1. time_embed.timestep_embedder <- time_in
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
"time_in.mlp.0.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
"time_in.mlp.0.bias"
)
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
"time_in.mlp.2.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
"time_in.mlp.2.bias"
)
# 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.0.bias")
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
)
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.2.bias")
)
# 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = (
original_state_dict.pop("txt_in.c_embedder.linear_1.weight")
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = (
original_state_dict.pop("txt_in.c_embedder.linear_1.bias")
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = (
original_state_dict.pop("txt_in.c_embedder.linear_2.weight")
)
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = (
original_state_dict.pop("txt_in.c_embedder.linear_2.bias")
)
# 4. context_embedder.proj_in <- txt_in.input_embedder
converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop(
"txt_in.input_embedder.weight"
)
converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")
# 5. context_embedder.token_refiner <- txt_in.individual_token_refiner
num_refiner_blocks = 2
for i in range(num_refiner_blocks):
block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}."
orig_prefix = f"txt_in.individual_token_refiner.blocks.{i}."
# norm1
converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(f"{orig_prefix}norm1.weight")
converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(f"{orig_prefix}norm1.bias")
# Split self_attn_qkv into to_q, to_k, to_v
qkv_weight = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.weight")
qkv_bias = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.bias")
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
# self_attn_proj -> attn.to_out.0
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
f"{orig_prefix}self_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
f"{orig_prefix}self_attn_proj.bias"
)
# norm2
converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(f"{orig_prefix}norm2.weight")
converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(f"{orig_prefix}norm2.bias")
# mlp -> ff
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
f"{orig_prefix}mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
f"{orig_prefix}mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
f"{orig_prefix}mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(f"{orig_prefix}mlp.fc2.bias")
# adaLN_modulation -> norm_out
converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop(
f"{orig_prefix}adaLN_modulation.1.weight"
)
converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop(
f"{orig_prefix}adaLN_modulation.1.bias"
)
# 6. context_embedder_2 <- byt5_in
converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight")
converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias")
converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight")
converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias")
converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight")
converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias")
converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight")
converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias")
# 7. image_embedder <- vision_in
converted_state_dict["image_embedder.norm_in.weight"] = original_state_dict.pop("vision_in.proj.0.weight")
converted_state_dict["image_embedder.norm_in.bias"] = original_state_dict.pop("vision_in.proj.0.bias")
converted_state_dict["image_embedder.linear_1.weight"] = original_state_dict.pop("vision_in.proj.1.weight")
converted_state_dict["image_embedder.linear_1.bias"] = original_state_dict.pop("vision_in.proj.1.bias")
converted_state_dict["image_embedder.linear_2.weight"] = original_state_dict.pop("vision_in.proj.3.weight")
converted_state_dict["image_embedder.linear_2.bias"] = original_state_dict.pop("vision_in.proj.3.bias")
converted_state_dict["image_embedder.norm_out.weight"] = original_state_dict.pop("vision_in.proj.4.weight")
converted_state_dict["image_embedder.norm_out.bias"] = original_state_dict.pop("vision_in.proj.4.bias")
# 8. x_embedder <- img_in
converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight")
converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias")
# 9. cond_type_embed <- cond_type_embedding
converted_state_dict["cond_type_embed.weight"] = original_state_dict.pop("cond_type_embedding.weight")
# 10. transformer_blocks <- double_blocks
num_layers = 54
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
orig_prefix = f"double_blocks.{i}."
# norm1 (img_mod)
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
f"{orig_prefix}img_mod.linear.weight"
)
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
f"{orig_prefix}img_mod.linear.bias"
)
# norm1_context (txt_mod)
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_mod.linear.weight"
)
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_mod.linear.bias"
)
# img attention (to_q, to_k, to_v)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_q.weight"
)
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
f"{orig_prefix}img_attn_q.bias"
)
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_k.weight"
)
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
f"{orig_prefix}img_attn_k.bias"
)
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_v.weight"
)
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
f"{orig_prefix}img_attn_v.bias"
)
# img attention qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_q_norm.weight"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_k_norm.weight"
)
# img attention output projection
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
f"{orig_prefix}img_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
f"{orig_prefix}img_attn_proj.bias"
)
# txt attention (add_q_proj, add_k_proj, add_v_proj)
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_q.weight"
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_q.bias"
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_k.weight"
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_k.bias"
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_v.weight"
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_v.bias"
)
# txt attention qk norm
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_q_norm.weight"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_k_norm.weight"
)
# txt attention output projection
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_attn_proj.bias"
)
# norm2 and norm2_context (these don't have weights in the original, they're LayerNorm with elementwise_affine=False)
# So we skip them
# img_mlp -> ff
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
f"{orig_prefix}img_mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
f"{orig_prefix}img_mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
f"{orig_prefix}img_mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
f"{orig_prefix}img_mlp.fc2.bias"
)
# txt_mlp -> ff_context
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_mlp.fc1.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_mlp.fc1.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
f"{orig_prefix}txt_mlp.fc2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
f"{orig_prefix}txt_mlp.fc2.bias"
)
# 11. norm_out and proj_out <- final_layer
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop(
"final_layer.adaLN_modulation.1.weight"
))
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.adaLN_modulation.1.bias"))
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
return converted_state_dict
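# Optional sanity check (sketch, not part of the script): after the mapping above, the original
# state dict should be fully consumed, e.g.
#     leftover = list(original_state_dict.keys())
#     assert not leftover, f"unconverted keys: {leftover[:5]}"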
def load_sharded_safetensors(dir: pathlib.Path):
file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
state_dict = {}
for path in file_paths:
state_dict.update(load_file(path))
return state_dict
def load_original_state_dict(args):
if args.original_state_dict_repo_id is not None:
model_dir = snapshot_download(
args.original_state_dict_repo_id,
repo_type="model",
allow_patterns="transformer/" + args.transformer_type + "/*"
)
elif args.original_state_dict_folder is not None:
model_dir = pathlib.Path(args.original_state_dict_folder)
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`")
model_dir = pathlib.Path(model_dir)
model_dir = model_dir / "transformer" / args.transformer_type
return load_sharded_safetensors(model_dir)
def convert_transformer(args):
original_state_dict = load_original_state_dict(args)
config = TRANSFORMER_CONFIGS[args.transformer_type]
with init_empty_weights():
transformer = HunyuanVideo15Transformer3DModel(**config)
state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict)
transformer.load_state_dict(state_dict, strict=True, assign=True)
return transformer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
)
parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Folder name of the original state dict")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
parser.add_argument(
"--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())
)
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
args = get_args()
dtype = DTYPE_MAPPING[args.dtype]
transformer = convert_transformer(args)
transformer = transformer.to(dtype=dtype)
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
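# A quick smoke test after conversion is to reload the saved weights (sketch; `args.output_path`
# is whatever was passed on the command line):
#     transformer = HunyuanVideo15Transformer3DModel.from_pretrained(args.output_path, torch_dtype=dtype)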

src/diffusers/__init__.py

@@ -189,6 +189,7 @@ else:
"AutoencoderKLHunyuanImage",
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
@@ -223,6 +224,7 @@ else:
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
"HunyuanVideo15Transformer3DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
@@ -903,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -939,6 +942,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
HunyuanVideo15Transformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
Kandinsky5Transformer3DModel,

src/diffusers/models/__init__.py

@@ -38,6 +38,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
@@ -83,6 +84,7 @@ if is_torch_available():
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
@@ -143,6 +145,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImage,
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -191,6 +194,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxTransformer2DModel,
HunyuanVideo15Transformer3DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,

src/diffusers/models/autoencoders/__init__.py

@@ -7,6 +7,7 @@ from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py

@@ -0,0 +1,968 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanVideo15CausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
super().__init__()
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
self.pad_mode = pad_mode
self.time_causal_padding = (
kernel_size[0] // 2,
kernel_size[0] // 2,
kernel_size[1] // 2,
kernel_size[1] // 2,
kernel_size[2] - 1,
0,
)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
return self.conv(hidden_states)
class HunyuanVideo15RMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of channels (the size of the feature dimension being normalized).
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
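# Note: `F.normalize(x) * dim**0.5 * gamma` is an RMS-norm formulation: dividing by the L2 norm
# and rescaling by sqrt(dim) is equivalent to dividing by the root-mean-square of the features.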
class HunyuanVideo15AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = HunyuanVideo15RMS_norm(in_channels, images=False)
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
@staticmethod
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
"""Prepare a causal attention mask for 3D videos.
Args:
n_frame (int): Number of frames (temporal length).
n_hw (int): Product of height and width.
dtype: Desired mask dtype.
device: Device for the mask.
batch_size (int, optional): If set, expands for batch.
Returns:
torch.Tensor: Causal attention mask.
"""
seq_len = n_frame * n_hw
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // n_hw
mask[i, : (i_frame + 1) * n_hw] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
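# The row-by-row loop above can also be written without a Python loop (sketch, not used here):
#     frame_idx = torch.arange(seq_len, device=device) // n_hw
#     allowed = frame_idx[None, :] <= frame_idx[:, None]
#     mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device).masked_fill(~allowed, float("-inf"))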
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
x = self.norm(x)
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
batch_size, channels, frames, height, width = query.shape
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
attention_mask = self.prepare_causal_attention_mask(frames, height * width, query.dtype, query.device, batch_size=batch_size)
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
# batch_size, 1, frames * height * width, channels
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
x = self.proj_out(x)
return x + identity
class HunyuanVideo15Upsample(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels * factor, kernel_size=3)
self.add_temporal_upsample = add_temporal_upsample
self.repeats = factor * out_channels // in_channels
@staticmethod
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
Args:
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
r1: temporal upsampling factor
r2: height upsampling factor
r3: width upsampling factor
"""
b, packed_c, f, h, w = tensor.shape
factor = r1 * r2 * r3
c = packed_c // factor
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
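# (For reference, this matches einops.rearrange(tensor, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)");
#  einops itself is not a dependency of this module.)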
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_upsample else 1
h = self.conv(x)
if self.add_temporal_upsample:
h_first = h[:, :, :1, :, :]
h_first = self._dcae_upsample_rearrange(h_first, r1=1, r2=2, r3=2)
h_first = h_first[:, : h_first.shape[1] // 2]
h_next = h[:, :, 1:, :, :]
h_next = self._dcae_upsample_rearrange(h_next, r1=r1, r2=2, r3=2)
h = torch.cat([h_first, h_next], dim=2)
# shortcut computation
x_first = x[:, :, :1, :, :]
x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2)
x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1)
x_next = x[:, :, 1:, :, :]
x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2)
x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = torch.cat([x_first, x_next], dim=2)
else:
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
return h + shortcut
class HunyuanVideo15Downsample(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
self.group_size = factor * in_channels // out_channels
@staticmethod
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
"""
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
This packs spatial/temporal dimensions into channels (opposite of upsample)
"""
b, c, packed_f, packed_h, packed_w = tensor.shape
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
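# (For reference, this matches einops.rearrange(tensor, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w");
#  einops itself is not a dependency of this module.)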
def forward(self, x: torch.Tensor):
r1 = 2 if self.add_temporal_downsample else 1
h = self.conv(x)
if self.add_temporal_downsample:
h_first = h[:, :, :1, :, :]
h_first = self._dcae_downsample_rearrange(h_first, r1=1, r2=2, r3=2)
h_first = torch.cat([h_first, h_first], dim=1)
h_next = h[:, :, 1:, :, :]
h_next = self._dcae_downsample_rearrange(h_next, r1=r1, r2=2, r3=2)
h = torch.cat([h_first, h_next], dim=2)
# shortcut computation
x_first = x[:, :, :1, :, :]
x_first = self._dcae_downsample_rearrange(x_first, r1=1, r2=2, r3=2)
B, C, T, H, W = x_first.shape
x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
x_next = x[:, :, 1:, :, :]
x_next = self._dcae_downsample_rearrange(x_next, r1=r1, r2=2, r3=2)
B, C, T, H, W = x_next.shape
x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
shortcut = torch.cat([x_first, x_next], dim=2)
else:
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
return h + shortcut
class HunyuanVideo15ResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
non_linearity: str = "swish",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_activation(non_linearity)
self.norm1 = HunyuanVideo15RMS_norm(in_channels, images=False)
self.conv1 = HunyuanVideo15CausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = HunyuanVideo15RMS_norm(out_channels, images=False)
self.conv2 = HunyuanVideo15CausalConv3d(out_channels, out_channels, kernel_size=3)
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual
class HunyuanVideo15MidBlock(nn.Module):
def __init__(
self,
in_channels: int,
num_layers: int = 1,
add_attention: bool = True,
) -> None:
super().__init__()
self.add_attention = add_attention
# There is always at least one resnet
resnets = [
HunyuanVideo15ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
]
attentions = []
for _ in range(num_layers):
if self.add_attention:
attentions.append(HunyuanVideo15AttnBlock(in_channels))
else:
attentions.append(None)
resnets.append(
HunyuanVideo15ResnetBlock(
in_channels=in_channels,
out_channels=in_channels,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states)
return hidden_states
class HunyuanVideo15DownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
downsample_out_channels: Optional[int] = None,
add_temporal_downsample: bool = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanVideo15ResnetBlock(
in_channels=in_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample_out_channels is not None:
self.downsamplers = nn.ModuleList(
[
HunyuanVideo15Downsample(
out_channels,
out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class HunyuanVideo15UpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
upsample_out_channels: Optional[int] = None,
add_temporal_upsample: bool = True,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanVideo15ResnetBlock(
in_channels=input_channels,
out_channels=out_channels,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample_out_channels is not None:
self.upsamplers = nn.ModuleList(
[
HunyuanVideo15Upsample(
out_channels,
out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for resnet in self.resnets:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class HunyuanVideo15Encoder3D(nn.Module):
r"""
Causal 3D VAE encoder used by HunyuanVideo 1.5.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 64,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
temporal_compression_ratio: int = 4,
spatial_compression_ratio: int = 16,
downsample_match_channel: bool = True,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.group_size = block_out_channels[-1] // self.out_channels
self.conv_in = HunyuanVideo15CausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
output_channel = block_out_channels[i]
if not add_spatial_downsample:
down_block = HunyuanVideo15DownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=None,
add_temporal_downsample=False,
)
input_channel = output_channel
else:
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
down_block = HunyuanVideo15DownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
downsample_out_channels=downsample_out_channels,
add_temporal_downsample=add_temporal_downsample,
)
input_channel = downsample_out_channels
self.down_blocks.append(down_block)
self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[-1])
self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
batch_size, _, frame, height, width = hidden_states.shape
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states += short_cut
return hidden_states
class HunyuanVideo15Decoder3D(nn.Module):
r"""
Causal 3D VAE decoder for video-like data used by HunyuanVideo 1.5.
"""
def __init__(
self,
in_channels: int = 32,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
upsample_match_channel: bool = True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.in_channels = in_channels
self.out_channels = out_channels
self.repeat = block_out_channels[0] // self.in_channels
self.conv_in = HunyuanVideo15CausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[0])
# up
input_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
output_channel = block_out_channels[i]
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
if add_spatial_upsample or add_temporal_upsample:
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
up_block = HunyuanVideo15UpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=upsample_out_channels,
add_temporal_upsample=add_temporal_upsample,
)
input_channel = upsample_out_channels
else:
up_block = HunyuanVideo15UpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
upsample_out_channels=None,
add_temporal_upsample=False,
)
input_channel = output_channel
self.up_blocks.append(up_block)
# out
self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
for up_block in self.up_blocks:
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
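# Note: `conv_in` above is combined with a channel-duplication shortcut (`repeat_interleave`),
# mirroring the channel-averaging shortcut applied at the end of the encoder.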
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanVideo-1.5.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 32,
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
layers_per_block: int = 2,
spatial_compression_ratio: int = 16,
temporal_compression_ratio: int = 4,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
scaling_factor: float = 1.03682,
) -> None:
super().__init__()
self.encoder = HunyuanVideo15Encoder3D(
in_channels=in_channels,
out_channels=latent_channels * 2,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
downsample_match_channel=downsample_match_channel,
)
self.decoder = HunyuanVideo15Decoder3D(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
layers_per_block=layers_per_block,
temporal_compression_ratio=temporal_compression_ratio,
spatial_compression_ratio=spatial_compression_ratio,
upsample_match_channel=upsample_match_channel,
)
self.spatial_compression_ratio = spatial_compression_ratio
self.temporal_compression_ratio = temporal_compression_ratio
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal tile height and width in latent space
self.tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio
self.tile_overlap_factor = 0.25
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_latent_min_height: Optional[int] = None,
tile_latent_min_width: Optional[int] = None,
tile_overlap_factor: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
processing larger videos.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_latent_min_height (`int`, *optional*):
The minimum height required for a latent to be separated into tiles across the height dimension.
tile_latent_min_width (`int`, *optional*):
The minimum width required for a latent to be separated into tiles across the width dimension.
tile_overlap_factor (`float`, *optional*):
The fraction of overlap between adjacent tiles.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = tile_latent_min_height or self.tile_latent_min_height
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
x = self.encoder(x)
return x
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of videos into latents.
Args:
x (`torch.Tensor`): Input batch of videos.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
return self.tiled_decode(z)
dec = self.decoder(z)
return dec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of video latents.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, _, height, width = x.shape
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
blend_height = int(self.tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
blend_width = int(self.tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
row_limit_height = self.tile_latent_min_height - blend_height # 8 - 2 = 6
row_limit_width = self.tile_latent_min_width - blend_width # 8 - 2 = 6
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""
Decode a batch of videos using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
Returns:
`torch.Tensor`:
The decoded videos.
"""
_, _, _, height, width = z.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192
row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
tile = z[
:,
:,
:,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
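# Minimal usage sketch (illustrative only; the checkpoint location and `subfolder` name are assumptions):
#     vae = AutoencoderKLHunyuanVideo15.from_pretrained("<path-or-repo>", subfolder="vae")
#     video = torch.randn(1, 3, 5, 256, 256)            # (batch, channels, frames, height, width)
#     latents = vae.encode(video).latent_dist.sample()  # 32 latent channels, 16x spatial / 4x temporal compression
#     frames = vae.decode(latents).sample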

src/diffusers/models/transformers/__init__.py

@@ -44,3 +44,4 @@ if is_torch_available():
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_animate import WanAnimateTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel

src/diffusers/models/transformers/transformer_hunyuan_video15.py

@@ -0,0 +1,937 @@
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
TimestepEmbedding,
Timesteps,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class HunyuanVideo15AttnProcessor2_0:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"HunyuanVideo15AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if attn.add_q_proj is None and encoder_hidden_states is not None:
assert False # YiYi Notes: remove this condition if this code path is never used
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
else:
assert False
# YiYi Notes: remove this condition if this code path is never used
if attn.norm_k is not None:
key = attn.norm_k(key)
else:
assert False
# YiYi Notes: remove this condition if this code path is never used
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
if attn.add_q_proj is None and encoder_hidden_states is not None:
assert False # YiYi Notes: remove this condition if this code path is never used
query = torch.cat(
[
apply_rotary_emb(
query[:, : -encoder_hidden_states.shape[1]],
image_rotary_emb,
sequence_dim=1,
),
query[:, -encoder_hidden_states.shape[1] :],
],
dim=1,
)
key = torch.cat(
[
apply_rotary_emb(
key[:, : -encoder_hidden_states.shape[1]],
image_rotary_emb,
sequence_dim=1,
),
key[:, -encoder_hidden_states.shape[1] :],
],
dim=1,
)
else:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
# 4. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([query, encoder_query], dim=1)
key = torch.cat([key, encoder_key], dim=1)
value = torch.cat([value, encoder_value], dim=1)
else:
assert False # YiYi Notes: remove this condition if this code path is never used
batch_size, seq_len, heads, dim = query.shape
print(f" query.shape: {query.shape}")
print(f" attention_mask.shape: {attention_mask.shape}")
attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True)
print(f" attention_mask.shape: {attention_mask.shape}")
attention_mask = attention_mask.bool()
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
attention_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
# 5. Attention
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
print(f" hidden_states.shape: {hidden_states.shape}")
print(f" hidden_states[0,:10,:3]: {hidden_states[0,:10,:3]}")
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, : -encoder_hidden_states.shape[1]],
hidden_states[:, -encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
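# Sequence layout in the processor above: latent tokens come first and text tokens are appended at the
# end; RoPE is applied to the latent part only, and the 1D text mask is left-padded with True so that
# every latent position remains attendable in the joint self-attention.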
class HunyuanVideoPatchEmbed(nn.Module):
def __init__(
self,
patch_size: Union[int, Tuple[int, int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
super().__init__()
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
return hidden_states
class HunyuanVideoAdaNorm(nn.Module):
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
super().__init__()
out_features = out_features or 2 * in_features
self.linear = nn.Linear(in_features, out_features)
self.nonlinearity = nn.SiLU()
def forward(
self, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
temb = self.linear(self.nonlinearity(temb))
gate_msa, gate_mlp = temb.chunk(2, dim=1)
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
return gate_msa, gate_mlp
class HunyuanVideo15TimeEmbedding(nn.Module):
r"""
Time embedding for HunyuanVideo 1.5.
Supports standard timestep embedding and optional reference timestep embedding
for MeanFlow-based super-resolution models.
Args:
embedding_dim (`int`):
The dimension of the output embedding.
use_meanflow (`bool`, defaults to `False`):
Whether to support reference timestep embeddings for temporal consistency.
Set to `True` for super-resolution models.
"""
def __init__(
self,
embedding_dim: int,
use_meanflow: bool = False,
):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_meanflow = use_meanflow
self.time_proj_r = None
self.timestep_embedder_r = None
if use_meanflow:
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(
self,
timestep: torch.Tensor,
timestep_r: Optional[torch.Tensor] = None,
) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
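        # `time_proj` produces a 256-dim sinusoidal encoding of the timestep, which `timestep_embedder` maps to the
        # model's embedding dimension.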
if timestep_r is not None:
timesteps_proj_r = self.time_proj_r(timestep_r)
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
timesteps_emb = timesteps_emb + timesteps_emb_r
return timesteps_emb
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
)
gate_msa, gate_mlp = self.norm_out(temb)
hidden_states = hidden_states + attn_output * gate_msa
ff_output = self.ff(self.norm2(hidden_states))
hidden_states = hidden_states + ff_output * gate_mlp
return hidden_states
class HunyuanVideoIndividualTokenRefiner(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
self.refiner_blocks = nn.ModuleList(
[
HunyuanVideoIndividualTokenRefinerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
for _ in range(num_layers)
]
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
self_attn_mask = None
if attention_mask is not None:
            # Expand the 1D padding mask of shape (batch_size, seq_len) into a symmetric 4D self-attention mask of
            # shape (batch_size, 1, seq_len, seq_len) so that padded tokens neither attend nor are attended to.
batch_size = attention_mask.shape[0]
seq_len = attention_mask.shape[1]
attention_mask = attention_mask.to(hidden_states.device).bool()
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
for block in self.refiner_blocks:
hidden_states = block(hidden_states, temb, self_attn_mask)
return hidden_states
class HunyuanVideoTokenRefiner(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
attention_head_dim: int,
num_layers: int,
mlp_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
attention_bias: bool = True,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=hidden_size, pooled_projection_dim=in_channels
)
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
self.token_refiner = HunyuanVideoIndividualTokenRefiner(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_layers=num_layers,
mlp_width_ratio=mlp_ratio,
mlp_drop_rate=mlp_drop_rate,
attention_bias=attention_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if attention_mask is None:
pooled_projections = hidden_states.mean(dim=1)
else:
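                # Masked mean pooling: average only over the valid (unpadded) text tokens.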
original_dtype = hidden_states.dtype
mask_float = attention_mask.float().unsqueeze(-1)
pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
pooled_projections = pooled_projections.to(original_dtype)
temb = self.time_text_embed(timestep, pooled_projections)
print(f" temb(time_text_embed).shape: {temb.shape}, {temb[0,:10]}")
hidden_states = self.proj_in(hidden_states)
print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,:3,:3]}")
print(f" temb: {temb.shape}, {temb[0,:10]}")
print(f" attention_mask: {attention_mask.shape}, {attention_mask[0,:3]}, {attention_mask.abs().sum()}")
print(f" -> token_refiner")
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
print(f" hidden_states(token_refiner) {hidden_states.shape}, {hidden_states[0,:3,:3]}")
return hidden_states
class HunyuanVideoRotaryPosEmbed(nn.Module):
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.rope_dim = rope_dim
self.theta = theta
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
axes_grids = []
for i in range(3):
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
# original implementation creates it on CPU and then moves it to device. This results in numerical
# differences in layerwise debugging outputs, but visually it is the same.
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
axes_grids.append(grid)
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
grid = torch.stack(grid, dim=0) # [3, W, H, T]
freqs = []
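        # Each spatio-temporal axis gets its own 1D rotary embedding; the entries of `rope_dim` are expected to sum
        # to the attention head dimension.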
for i in range(3):
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
freqs.append(freq)
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
return freqs_cos, freqs_sin
# Copied from diffusers.models.transformers.transformer_hunyuanimage.HunyuanImageByT5TextProjection
class HunyuanVideo15ByT5TextProjection(nn.Module):
def __init__(self, in_features: int, hidden_size: int, out_features: int):
super().__init__()
self.norm = nn.LayerNorm(in_features)
self.linear_1 = nn.Linear(in_features, hidden_size)
self.linear_2 = nn.Linear(hidden_size, hidden_size)
self.linear_3 = nn.Linear(hidden_size, out_features)
self.act_fn = nn.GELU()
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm(encoder_hidden_states)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_3(hidden_states)
return hidden_states
class HunyuanVideo15ImageProjection(nn.Module):
def __init__(self, in_channels: int, hidden_size: int):
super().__init__()
self.norm_in = nn.LayerNorm(in_channels)
self.linear_1 = nn.Linear(in_channels, in_channels)
self.act_fn = nn.GELU()
self.linear_2 = nn.Linear(in_channels, hidden_size)
self.norm_out = nn.LayerNorm(hidden_size)
def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm_in(image_embeds)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.norm_out(hidden_states)
return hidden_states
class HunyuanVideoTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
added_kv_proj_dim=hidden_size,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
context_pre_only=False,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
print(f" norm_hidden_states(norm1).shape: {norm_hidden_states.shape}, {norm_hidden_states[0,:10,:3]}")
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
print(f" norm_encoder_hidden_states(norm1_context).shape: {norm_encoder_hidden_states.shape}, {norm_encoder_hidden_states[0,:10,:3]}")
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=freqs_cis,
)
print(f" attn_output.shape: {attn_output.shape}, {attn_output[0,:10,:3]}")
print(f" context_attn_output.shape: {context_attn_output.shape}, {context_attn_output[0,:10,:3]}")
# 3. Modulation and residual connection
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
print(f" hidden_states(ff): {hidden_states.shape}, {hidden_states[0,:10,:3]}")
print(f" encoder_hidden_states(ff): {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
return hidden_states, encoder_hidden_states
class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
    A Transformer model for video-like data used in [HunyuanVideo 1.5](https://huggingface.co/tencent/HunyuanVideo).

    Args:
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        num_attention_heads (`int`, defaults to `24`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        num_layers (`int`, defaults to `20`):
            The number of layers of dual-stream blocks to use.
        num_refiner_layers (`int`, defaults to `2`):
            The number of layers of refiner blocks to use.
        mlp_ratio (`float`, defaults to `4.0`):
            The ratio of the hidden layer size to the input size in the feedforward network.
        patch_size (`int`, defaults to `1`):
            The size of the spatial patches to use in the patch embedding layer.
        patch_size_t (`int`, defaults to `1`):
            The size of the temporal patches to use in the patch embedding layer.
        qk_norm (`str`, defaults to `"rms_norm"`):
            The normalization to use for the query and key projections in the attention layers.
        text_embed_dim (`int`, defaults to `3584`):
            Input dimension of text embeddings from the primary text encoder.
        text_embed_2_dim (`int`, defaults to `1472`):
            Input dimension of text embeddings from the secondary (ByT5) text encoder.
        image_embed_dim (`int`, defaults to `1152`):
            Input dimension of the image embeddings used for image-to-video conditioning.
        rope_theta (`float`, defaults to `256.0`):
            The value of theta to use in the RoPE layer.
        rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions of the axes to use in the RoPE layer.
        use_meanflow (`bool`, defaults to `False`):
            Whether to add a reference timestep embedding (used by the MeanFlow-based super-resolution variants).
    """
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
_repeated_blocks = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
num_attention_heads: int = 24,
attention_head_dim: int = 128,
num_layers: int = 20,
num_refiner_layers: int = 2,
mlp_ratio: float = 4.0,
patch_size: int = 1,
patch_size_t: int = 1,
qk_norm: str = "rms_norm",
text_embed_dim: int = 3584,
text_embed_2_dim: int = 1472,
image_embed_dim: int = 1152,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
use_meanflow: bool = False,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Latent and condition embedders
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim)
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow)
self.cond_type_embed = nn.Embedding(3, inner_dim)
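        # Learned condition-type embedding: index 0 marks primary (Qwen) text tokens, 1 marks ByT5 text tokens and
        # 2 marks image tokens (see `forward`).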
# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
        # 4. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
encoder_attention_mask_2: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
timestep_r: Optional[torch.LongTensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
temb = self.time_embed(timestep, timestep_r=timestep_r)
hidden_states = self.x_embedder(hidden_states)
# qwen text embedding
print(f" encoder_hidden_states(qwen).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
print(f" timestep: {timestep}, {timestep[:10]}")
print(f" encoder_attention_mask: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.abs().sum()}")
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
print(f" encoder_hidden_states(token_refiner).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
encoder_hidden_states_cond_emb = self.cond_type_embed(
torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long)
)
encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb
print(f" encoder_hidden_states(+ cond_emb).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
# byt5 text embedding
encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
print(f" encoder_hidden_states_2(byt5).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}")
encoder_hidden_states_2_cond_emb = self.cond_type_embed(
torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long)
)
encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb
print(f" encoder_hidden_states_2(+ cond_emb).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}")
# image embed
encoder_hidden_states_3 = self.image_embedder(image_embeds)
print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
is_t2v = torch.all(image_embeds == 0)
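        # All-zero image embeddings indicate a text-to-video input, in which case the image tokens are zeroed and
        # fully masked out.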
if is_t2v:
encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0
encoder_attention_mask_3 = torch.zeros(
(batch_size, encoder_hidden_states_3.shape[1]),
dtype=encoder_attention_mask.dtype,
device=encoder_attention_mask.device,
)
print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
print(f" encoder_attention_mask_3: {encoder_attention_mask_3.shape}, {encoder_attention_mask_3[0,:10]}, {encoder_attention_mask_3.abs().sum()}")
else:
encoder_attention_mask_3 = torch.ones(
(batch_size, encoder_hidden_states_3.shape[1]),
dtype=encoder_attention_mask.dtype,
device=encoder_attention_mask.device,
)
encoder_hidden_states_3_cond_emb = self.cond_type_embed(
2 * torch.ones_like(
encoder_hidden_states_3[:, :, 0],
dtype=torch.long,
)
)
encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb
print(f" encoder_hidden_states_3(+ cond_emb).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
        # 3. Reorder and combine conditioning tokens (image, ByT5 text, Qwen text): valid tokens first, then padding
encoder_attention_mask = encoder_attention_mask.bool()
encoder_attention_mask_2 = encoder_attention_mask_2.bool()
encoder_attention_mask_3 = encoder_attention_mask_3.bool()
new_encoder_hidden_states = []
new_encoder_attention_mask = []
for text, text_mask, text_2, text_mask_2, image, image_mask in zip(
encoder_hidden_states,
encoder_attention_mask,
encoder_hidden_states_2,
encoder_attention_mask_2,
encoder_hidden_states_3,
encoder_attention_mask_3,
):
# Concatenate: [valid_image, valid_byt5, valid_mllm, invalid_image, invalid_byt5, invalid_mllm]
new_encoder_hidden_states.append(
torch.cat(
[
image[image_mask], # valid image
text_2[text_mask_2], # valid byt5
text[text_mask], # valid mllm
image[~image_mask], # invalid image
torch.zeros_like(text_2[~text_mask_2]), # invalid byt5 (zeroed)
torch.zeros_like(text[~text_mask]), # invalid mllm (zeroed)
],
dim=0,
)
)
# Apply same reordering to attention masks
new_encoder_attention_mask.append(
torch.cat(
[
image_mask[image_mask],
text_mask_2[text_mask_2],
text_mask[text_mask],
image_mask[~image_mask],
text_mask_2[~text_mask_2],
text_mask[~text_mask],
],
dim=0,
)
)
encoder_hidden_states = torch.stack(new_encoder_hidden_states)
encoder_attention_mask = torch.stack(new_encoder_attention_mask)
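        # The reordered mask covers only the conditioning tokens; the attention processor later left-pads it with
        # `True` so that it also covers the video tokens.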
print(f" hidden_states.shape: {hidden_states.shape}, {hidden_states[0,:3,:3]}")
print(f" encoder_hidden_states.shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
print(f" encoder_attention_mask.shape: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.dtype}, {encoder_attention_mask.sum()}")
print(f" image_rotary_emb: {image_rotary_emb[0].shape}, {image_rotary_emb[1].shape}, {image_rotary_emb[0][:3,:10]}, {image_rotary_emb[1][:3,:10]}")
print(f" temb.shape: {temb.shape}, {temb[0,:10]}")
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
encoder_attention_mask,
image_rotary_emb,
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
encoder_attention_mask,
image_rotary_emb,
)
# 5. Output projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
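        # Un-patchify: fold the per-token (p_t * p_h * p_w * out_channels) predictions back into a latent video of
        # shape (batch_size, out_channels, num_frames, height, width).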
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p_h, p_w
)
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)