mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add
This commit is contained in:
382
scripts/convert_hunyuan_video1_5_to_diffusers.py
Normal file
382
scripts/convert_hunyuan_video1_5_to_diffusers.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
|
||||
--original_state_dict_folder /raid/yiyi/new-model-vid \
|
||||
--output_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \
|
||||
--transformer_type 480p_i2v \
|
||||
--dtype fp32
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import pathlib
|
||||
from diffusers import HunyuanVideo15Transformer3DModel
|
||||
|
||||
TRANSFORMER_CONFIGS = {
|
||||
"480p_i2v": {
|
||||
"in_channels": 65,
|
||||
"out_channels": 32,
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 54,
|
||||
"num_refiner_layers": 2,
|
||||
"mlp_ratio": 4.0,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"qk_norm": "rms_norm",
|
||||
"text_embed_dim": 3584,
|
||||
"text_embed_2_dim": 1472,
|
||||
"image_embed_dim": 1152,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
"use_meanflow": False,
|
||||
},
|
||||
}
|
||||
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
|
||||
"""
|
||||
Convert HunyuanVideo 1.5 original checkpoint to Diffusers format.
|
||||
"""
|
||||
converted_state_dict = {}
|
||||
|
||||
# 1. time_embed.timestep_embedder <- time_in
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
|
||||
"time_in.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
|
||||
"time_in.mlp.0.bias"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
|
||||
"time_in.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
|
||||
"time_in.mlp.2.bias"
|
||||
)
|
||||
|
||||
# 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
|
||||
original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = (
|
||||
original_state_dict.pop("txt_in.t_embedder.mlp.0.bias")
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
|
||||
original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = (
|
||||
original_state_dict.pop("txt_in.t_embedder.mlp.2.bias")
|
||||
)
|
||||
|
||||
# 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder
|
||||
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = (
|
||||
original_state_dict.pop("txt_in.c_embedder.linear_1.weight")
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = (
|
||||
original_state_dict.pop("txt_in.c_embedder.linear_1.bias")
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = (
|
||||
original_state_dict.pop("txt_in.c_embedder.linear_2.weight")
|
||||
)
|
||||
converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = (
|
||||
original_state_dict.pop("txt_in.c_embedder.linear_2.bias")
|
||||
)
|
||||
|
||||
# 4. context_embedder.proj_in <- txt_in.input_embedder
|
||||
converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop(
|
||||
"txt_in.input_embedder.weight"
|
||||
)
|
||||
converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")
|
||||
|
||||
# 5. context_embedder.token_refiner <- txt_in.individual_token_refiner
|
||||
num_refiner_blocks = 2
|
||||
for i in range(num_refiner_blocks):
|
||||
block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}."
|
||||
orig_prefix = f"txt_in.individual_token_refiner.blocks.{i}."
|
||||
|
||||
# norm1
|
||||
converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(f"{orig_prefix}norm1.weight")
|
||||
converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(f"{orig_prefix}norm1.bias")
|
||||
|
||||
# Split self_attn_qkv into to_q, to_k, to_v
|
||||
qkv_weight = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.weight")
|
||||
qkv_bias = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.bias")
|
||||
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
|
||||
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
|
||||
|
||||
# self_attn_proj -> attn.to_out.0
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}self_attn_proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}self_attn_proj.bias"
|
||||
)
|
||||
|
||||
# norm2
|
||||
converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(f"{orig_prefix}norm2.weight")
|
||||
converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(f"{orig_prefix}norm2.bias")
|
||||
|
||||
# mlp -> ff
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(f"{orig_prefix}mlp.fc2.bias")
|
||||
|
||||
# adaLN_modulation -> norm_out
|
||||
converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}adaLN_modulation.1.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}adaLN_modulation.1.bias"
|
||||
)
|
||||
|
||||
# 6. context_embedder_2 <- byt5_in
|
||||
converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight")
|
||||
converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias")
|
||||
converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight")
|
||||
converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias")
|
||||
converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight")
|
||||
converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias")
|
||||
converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight")
|
||||
converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias")
|
||||
|
||||
# 7. image_embedder <- vision_in
|
||||
converted_state_dict["image_embedder.norm_in.weight"] = original_state_dict.pop("vision_in.proj.0.weight")
|
||||
converted_state_dict["image_embedder.norm_in.bias"] = original_state_dict.pop("vision_in.proj.0.bias")
|
||||
converted_state_dict["image_embedder.linear_1.weight"] = original_state_dict.pop("vision_in.proj.1.weight")
|
||||
converted_state_dict["image_embedder.linear_1.bias"] = original_state_dict.pop("vision_in.proj.1.bias")
|
||||
converted_state_dict["image_embedder.linear_2.weight"] = original_state_dict.pop("vision_in.proj.3.weight")
|
||||
converted_state_dict["image_embedder.linear_2.bias"] = original_state_dict.pop("vision_in.proj.3.bias")
|
||||
converted_state_dict["image_embedder.norm_out.weight"] = original_state_dict.pop("vision_in.proj.4.weight")
|
||||
converted_state_dict["image_embedder.norm_out.bias"] = original_state_dict.pop("vision_in.proj.4.bias")
|
||||
|
||||
# 8. x_embedder <- img_in
|
||||
converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight")
|
||||
converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias")
|
||||
|
||||
# 9. cond_type_embed <- cond_type_embedding
|
||||
converted_state_dict["cond_type_embed.weight"] = original_state_dict.pop("cond_type_embedding.weight")
|
||||
|
||||
# 10. transformer_blocks <- double_blocks
|
||||
num_layers = 54
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
orig_prefix = f"double_blocks.{i}."
|
||||
|
||||
# norm1 (img_mod)
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mod.linear.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mod.linear.bias"
|
||||
)
|
||||
|
||||
# norm1_context (txt_mod)
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mod.linear.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mod.linear.bias"
|
||||
)
|
||||
|
||||
# img attention (to_q, to_k, to_v)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_q.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_q.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_k.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_k.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_v.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_v.bias"
|
||||
)
|
||||
|
||||
# img attention qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_k_norm.weight"
|
||||
)
|
||||
|
||||
# img attention output projection
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_attn_proj.bias"
|
||||
)
|
||||
|
||||
# txt attention (add_q_proj, add_k_proj, add_v_proj)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_q.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_q.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_k.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_k.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_v.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_v.bias"
|
||||
)
|
||||
|
||||
# txt attention qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_k_norm.weight"
|
||||
)
|
||||
|
||||
# txt attention output projection
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_proj.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_attn_proj.bias"
|
||||
)
|
||||
|
||||
# norm2 and norm2_context (these don't have weights in the original, they're LayerNorm with elementwise_affine=False)
|
||||
# So we skip them
|
||||
|
||||
# img_mlp -> ff
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}img_mlp.fc2.bias"
|
||||
)
|
||||
|
||||
# txt_mlp -> ff_context
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
|
||||
f"{orig_prefix}txt_mlp.fc2.bias"
|
||||
)
|
||||
|
||||
# 11. norm_out and proj_out <- final_layer
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop(
|
||||
"final_layer.adaLN_modulation.1.weight"
|
||||
))
|
||||
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.adaLN_modulation.1.bias"))
|
||||
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def load_sharded_safetensors(dir: pathlib.Path):
|
||||
file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
|
||||
state_dict = {}
|
||||
for path in file_paths:
|
||||
state_dict.update(load_file(path))
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_original_state_dict(args):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
model_dir = snapshot_download(
|
||||
args.original_state_dict_repo_id,
|
||||
repo_type="model",
|
||||
allow_patterns="transformer/" + args.transformer_type + "/*"
|
||||
)
|
||||
elif args.original_state_dict_folder is not None:
|
||||
model_dir = pathlib.Path(args.original_state_dict_folder)
|
||||
else:
|
||||
raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`")
|
||||
model_dir = pathlib.Path(model_dir)
|
||||
model_dir = model_dir / "transformer" / args.transformer_type
|
||||
return load_sharded_safetensors(model_dir)
|
||||
|
||||
def convert_transformer(args):
|
||||
original_state_dict = load_original_state_dict(args)
|
||||
|
||||
config = TRANSFORMER_CONFIGS[args.transformer_type]
|
||||
with init_empty_weights():
|
||||
transformer = HunyuanVideo15Transformer3DModel(**config)
|
||||
state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict)
|
||||
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
||||
|
||||
return transformer
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
|
||||
)
|
||||
parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Folder name of the original state dict")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||
parser.add_argument(
|
||||
"--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
transformer = convert_transformer(args)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
@@ -189,6 +189,7 @@ else:
|
||||
"AutoencoderKLHunyuanImage",
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
@@ -223,6 +224,7 @@ else:
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"HunyuanImageTransformer2DModel",
|
||||
"HunyuanVideo15Transformer3DModel",
|
||||
"HunyuanVideoFramepackTransformer3DModel",
|
||||
"HunyuanVideoTransformer3DModel",
|
||||
"I2VGenXLUNet",
|
||||
@@ -903,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -939,6 +942,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanImageTransformer2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
HunyuanVideo15Transformer3DModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
Kandinsky5Transformer3DModel,
|
||||
|
||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
@@ -83,6 +84,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
|
||||
@@ -143,6 +145,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -191,6 +194,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxTransformer2DModel,
|
||||
HunyuanVideo15Transformer3DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
|
||||
@@ -7,6 +7,7 @@ from .autoencoder_kl_cosmos import AutoencoderKLCosmos
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
|
||||
@@ -0,0 +1,968 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanVideo15CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
bias: bool = True,
|
||||
pad_mode: str = "replicate",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
self.time_causal_padding = (
|
||||
kernel_size[0] // 2,
|
||||
kernel_size[0] // 2,
|
||||
kernel_size[1] // 2,
|
||||
kernel_size[1] // 2,
|
||||
kernel_size[2] - 1,
|
||||
0,
|
||||
)
|
||||
|
||||
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class HunyuanVideo15RMS_norm(nn.Module):
|
||||
r"""
|
||||
A custom RMS normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The number of dimensions to normalize over.
|
||||
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
||||
Default is True.
|
||||
images (bool, optional): Whether the input represents image data. Default is True.
|
||||
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanVideo15AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = HunyuanVideo15RMS_norm(in_channels, images=False)
|
||||
|
||||
self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
@staticmethod
|
||||
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
|
||||
"""Prepare a causal attention mask for 3D videos.
|
||||
|
||||
Args:
|
||||
n_frame (int): Number of frames (temporal length).
|
||||
n_hw (int): Product of height and width.
|
||||
dtype: Desired mask dtype.
|
||||
device: Device for the mask.
|
||||
batch_size (int, optional): If set, expands for batch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Causal attention mask.
|
||||
"""
|
||||
seq_len = n_frame * n_hw
|
||||
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
||||
for i in range(seq_len):
|
||||
i_frame = i // n_hw
|
||||
mask[i, : (i_frame + 1) * n_hw] = 0
|
||||
if batch_size is not None:
|
||||
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
return mask
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = x
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
query = self.to_q(x)
|
||||
key = self.to_k(x)
|
||||
value = self.to_v(x)
|
||||
|
||||
batch_size, channels, frames, height, width = query.shape
|
||||
|
||||
query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
|
||||
|
||||
attention_mask = self.prepare_causal_attention_mask(frames, height * width, query.dtype, query.device, batch_size=batch_size)
|
||||
|
||||
x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
|
||||
|
||||
# batch_size, 1, frames * height * width, channels
|
||||
|
||||
x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
|
||||
x = self.proj_out(x)
|
||||
|
||||
return x + identity
|
||||
|
||||
|
||||
class HunyuanVideo15Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
|
||||
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels * factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_upsample = add_temporal_upsample
|
||||
self.repeats = factor * out_channels // in_channels
|
||||
|
||||
@staticmethod
|
||||
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
|
||||
"""
|
||||
Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
|
||||
|
||||
Args:
|
||||
tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
|
||||
r1: temporal upsampling factor
|
||||
r2: height upsampling factor
|
||||
r3: width upsampling factor
|
||||
"""
|
||||
b, packed_c, f, h, w = tensor.shape
|
||||
factor = r1 * r2 * r3
|
||||
c = packed_c // factor
|
||||
|
||||
tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
|
||||
tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
return tensor.reshape(b, c, f * r1, h * r2, w * r3)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
r1 = 2 if self.add_temporal_upsample else 1
|
||||
h = self.conv(x)
|
||||
if self.add_temporal_upsample:
|
||||
h_first = h[:, :, :1, :, :]
|
||||
h_first = self._dcae_upsample_rearrange(h_first, r1=1, r2=2, r3=2)
|
||||
h_first = h_first[:, : h_first.shape[1] // 2]
|
||||
h_next = h[:, :, 1:, :, :]
|
||||
h_next = self._dcae_upsample_rearrange(h_next, r1=r1, r2=2, r3=2)
|
||||
h = torch.cat([h_first, h_next], dim=2)
|
||||
|
||||
# shortcut computation
|
||||
x_first = x[:, :, :1, :, :]
|
||||
x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2)
|
||||
x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1)
|
||||
|
||||
x_next = x[:, :, 1:, :, :]
|
||||
x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2)
|
||||
x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
shortcut = torch.cat([x_first, x_next], dim=2)
|
||||
|
||||
else:
|
||||
h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
|
||||
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanVideo15Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
|
||||
super().__init__()
|
||||
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
|
||||
assert out_channels % factor == 0
|
||||
# self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
|
||||
|
||||
self.add_temporal_downsample = add_temporal_downsample
|
||||
self.group_size = factor * in_channels // out_channels
|
||||
|
||||
@staticmethod
|
||||
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
|
||||
"""
|
||||
Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
|
||||
|
||||
This packs spatial/temporal dimensions into channels (opposite of upsample)
|
||||
"""
|
||||
b, c, packed_f, packed_h, packed_w = tensor.shape
|
||||
f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
|
||||
|
||||
tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
|
||||
tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
r1 = 2 if self.add_temporal_downsample else 1
|
||||
h = self.conv(x)
|
||||
if self.add_temporal_downsample:
|
||||
h_first = h[:, :, :1, :, :]
|
||||
h_first = self._dcae_downsample_rearrange(h_first, r1=1, r2=2, r3=2)
|
||||
h_first = torch.cat([h_first, h_first], dim=1)
|
||||
h_next = h[:, :, 1:, :, :]
|
||||
h_next = self._dcae_downsample_rearrange(h_next, r1=r1, r2=2, r3=2)
|
||||
h = torch.cat([h_first, h_next], dim=2)
|
||||
|
||||
# shortcut computation
|
||||
x_first = x[:, :, :1, :, :]
|
||||
x_first = self._dcae_downsample_rearrange(x_first, r1=1, r2=2, r3=2)
|
||||
B, C, T, H, W = x_first.shape
|
||||
x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
|
||||
x_next = x[:, :, 1:, :, :]
|
||||
x_next = self._dcae_downsample_rearrange(x_next, r1=r1, r2=2, r3=2)
|
||||
B, C, T, H, W = x_next.shape
|
||||
x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
|
||||
shortcut = torch.cat([x_first, x_next], dim=2)
|
||||
else:
|
||||
h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
|
||||
shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
|
||||
B, C, T, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
|
||||
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class HunyuanVideo15ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
non_linearity: str = "swish",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.norm1 = HunyuanVideo15RMS_norm(in_channels, images=False)
|
||||
self.conv1 = HunyuanVideo15CausalConv3d(in_channels, out_channels, kernel_size=3)
|
||||
|
||||
self.norm2 = HunyuanVideo15RMS_norm(out_channels, images=False)
|
||||
self.conv2 = HunyuanVideo15CausalConv3d(out_channels, out_channels, kernel_size=3)
|
||||
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class HunyuanVideo15MidBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_layers: int = 1,
|
||||
add_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.add_attention = add_attention
|
||||
|
||||
# There is always at least one resnet
|
||||
resnets = [
|
||||
HunyuanVideo15ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
if self.add_attention:
|
||||
attentions.append(HunyuanVideo15AttnBlock(in_channels))
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
HunyuanVideo15ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15DownBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
downsample_out_channels: Optional[int] = None,
|
||||
add_temporal_downsample: int = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
HunyuanVideo15ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if downsample_out_channels is not None:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideo15Downsample(
|
||||
out_channels,
|
||||
out_channels=downsample_out_channels,
|
||||
add_temporal_downsample=add_temporal_downsample,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15UpBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
upsample_out_channels: Optional[int] = None,
|
||||
add_temporal_upsample: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
input_channels = in_channels if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
HunyuanVideo15ResnetBlock(
|
||||
in_channels=input_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if upsample_out_channels is not None:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideo15Upsample(
|
||||
out_channels,
|
||||
out_channels=upsample_out_channels,
|
||||
add_temporal_upsample=add_temporal_upsample,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
|
||||
|
||||
else:
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15Encoder3D(nn.Module):
|
||||
r"""
|
||||
3D vae encoder for HunyuanImageRefiner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 64,
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
temporal_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 16,
|
||||
downsample_match_channel: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.group_size = block_out_channels[-1] // self.out_channels
|
||||
|
||||
self.conv_in = HunyuanVideo15CausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
|
||||
self.mid_block = None
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
input_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
add_spatial_downsample = i < np.log2(spatial_compression_ratio)
|
||||
output_channel = block_out_channels[i]
|
||||
if not add_spatial_downsample:
|
||||
down_block = HunyuanVideo15DownBlock3D(
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
downsample_out_channels=None,
|
||||
add_temporal_downsample=False,
|
||||
)
|
||||
input_channel = output_channel
|
||||
else:
|
||||
add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
|
||||
downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
|
||||
down_block = HunyuanVideo15DownBlock3D(
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
downsample_out_channels=downsample_out_channels,
|
||||
add_temporal_downsample=add_temporal_downsample,
|
||||
)
|
||||
input_channel = downsample_out_channels
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[-1])
|
||||
|
||||
self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
|
||||
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
else:
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
# short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
|
||||
batch_size, _, frame, height, width = hidden_states.shape
|
||||
short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
hidden_states += short_cut
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15Decoder3D(nn.Module):
|
||||
r"""
|
||||
Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 32,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
upsample_match_channel: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.repeat = block_out_channels[0] // self.in_channels
|
||||
|
||||
self.conv_in = HunyuanVideo15CausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[0])
|
||||
|
||||
# up
|
||||
input_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
add_spatial_upsample = i < np.log2(spatial_compression_ratio)
|
||||
add_temporal_upsample = i < np.log2(temporal_compression_ratio)
|
||||
if add_spatial_upsample or add_temporal_upsample:
|
||||
upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
|
||||
up_block = HunyuanVideo15UpBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
upsample_out_channels=upsample_out_channels,
|
||||
add_temporal_upsample=add_temporal_upsample,
|
||||
)
|
||||
input_channel = upsample_out_channels
|
||||
else:
|
||||
up_block = HunyuanVideo15UpBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
upsample_out_channels=None,
|
||||
add_temporal_upsample=False,
|
||||
)
|
||||
input_channel = output_channel
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# out
|
||||
self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
|
||||
else:
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
# post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanVideo-1.5.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
downsample_match_channel: bool = True,
|
||||
upsample_match_channel: bool = True,
|
||||
scaling_factor: float = 1.03682,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = HunyuanVideo15Encoder3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels * 2,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
downsample_match_channel=downsample_match_channel,
|
||||
)
|
||||
|
||||
self.decoder = HunyuanVideo15Decoder3D(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=list(reversed(block_out_channels)),
|
||||
layers_per_block=layers_per_block,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
upsample_match_channel=upsample_match_channel,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = spatial_compression_ratio
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
self.use_slicing = False
|
||||
|
||||
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
||||
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
||||
# intermediate tiles together, the memory requirement can be lowered.
|
||||
self.use_tiling = False
|
||||
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 256
|
||||
self.tile_sample_min_width = 256
|
||||
|
||||
# The minimal tile height and width in latent space
|
||||
self.tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio
|
||||
self.tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_latent_min_height: Optional[int] = None,
|
||||
tile_latent_min_width: Optional[int] = None,
|
||||
tile_overlap_factor: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_latent_min_height (`int`, *optional*):
|
||||
The minimum height required for a latent to be separated into tiles across the height dimension.
|
||||
tile_latent_min_width (`int`, *optional*):
|
||||
The minimum width required for a latent to be separated into tiles across the width dimension.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_latent_min_height = tile_latent_min_height or self.tile_latent_min_height
|
||||
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x)
|
||||
|
||||
x = self.encoder(x)
|
||||
return x
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
||||
return self.tiled_decode(z)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
return dec
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of videos.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
|
||||
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
|
||||
blend_height = int(self.tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
|
||||
blend_width = int(self.tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
|
||||
row_limit_height = self.tile_latent_min_height - blend_height # 8 - 2 = 6
|
||||
row_limit_width = self.tile_latent_min_width - blend_width # 8 - 2 = 6
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
tile = x[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i : i + self.tile_sample_min_height,
|
||||
j : j + self.tile_sample_min_width,
|
||||
]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
moments = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return moments
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
|
||||
_, _, _, height, width = z.shape
|
||||
|
||||
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
|
||||
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
|
||||
blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
|
||||
blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
|
||||
row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192
|
||||
row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192
|
||||
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
tile = z[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i : i + self.tile_latent_min_height,
|
||||
j : j + self.tile_latent_min_width,
|
||||
]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
dec = torch.cat(result_rows, dim=-2)
|
||||
|
||||
return dec
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
return dec
|
||||
@@ -44,3 +44,4 @@ if is_torch_available():
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
from .transformer_wan_animate import WanAnimateTransformer3DModel
|
||||
from .transformer_wan_vace import WanVACETransformer3DModel
|
||||
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
|
||||
|
||||
937
src/diffusers/models/transformers/transformer_hunyuan_video15.py
Normal file
937
src/diffusers/models/transformers/transformer_hunyuan_video15.py
Normal file
@@ -0,0 +1,937 @@
|
||||
# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanVideo15AttnProcessor2_0:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"HunyuanVideo15AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
assert False # YiYi Notes: remove this condition if this code path is never used
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
else:
|
||||
assert False
|
||||
# YiYi Notes: remove this condition if this code path is never used
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
else:
|
||||
assert False
|
||||
# YiYi Notes: remove this condition if this code path is never used
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
assert False # YiYi Notes: remove this condition if this code path is never used
|
||||
query = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(
|
||||
query[:, : -encoder_hidden_states.shape[1]],
|
||||
image_rotary_emb,
|
||||
sequence_dim=1,
|
||||
),
|
||||
query[:, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
key = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(
|
||||
key[:, : -encoder_hidden_states.shape[1]],
|
||||
image_rotary_emb,
|
||||
sequence_dim=1,
|
||||
),
|
||||
key[:, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
# 4. Encoder condition QKV projection and normalization
|
||||
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=1)
|
||||
key = torch.cat([key, encoder_key], dim=1)
|
||||
value = torch.cat([value, encoder_value], dim=1)
|
||||
|
||||
else:
|
||||
assert False # YiYi Notes: remove this condition if this code path is never used
|
||||
|
||||
|
||||
batch_size, seq_len, heads, dim = query.shape
|
||||
print(f" query.shape: {query.shape}")
|
||||
print(f" attention_mask.shape: {attention_mask.shape}")
|
||||
attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True)
|
||||
print(f" attention_mask.shape: {attention_mask.shape}")
|
||||
attention_mask = attention_mask.bool()
|
||||
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
||||
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||
attention_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
print(f" hidden_states.shape: {hidden_states.shape}")
|
||||
print(f" hidden_states[0,:10,:3]: {hidden_states[0,:10,:3]}")
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 6. Output projection
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if getattr(attn, "to_add_out", None) is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Union[int, Tuple[int, int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoAdaNorm(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or 2 * in_features
|
||||
self.linear = nn.Linear(in_features, out_features)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
def forward(
|
||||
self, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
temb = self.linear(self.nonlinearity(temb))
|
||||
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
||||
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
||||
return gate_msa, gate_mlp
|
||||
|
||||
|
||||
class HunyuanVideo15TimeEmbedding(nn.Module):
|
||||
r"""
|
||||
Time embedding for HunyuanVideo 1.5.
|
||||
|
||||
Supports standard timestep embedding and optional reference timestep embedding
|
||||
for MeanFlow-based super-resolution models.
|
||||
|
||||
Args:
|
||||
embedding_dim (`int`):
|
||||
The dimension of the output embedding.
|
||||
use_meanflow (`bool`, defaults to `False`):
|
||||
Whether to support reference timestep embeddings for temporal consistency.
|
||||
Set to `True` for super-resolution models.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
use_meanflow: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.use_meanflow = use_meanflow
|
||||
|
||||
self.time_proj_r = None
|
||||
self.timestep_embedder_r = None
|
||||
if use_meanflow:
|
||||
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
timestep_r: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
|
||||
|
||||
if timestep_r is not None:
|
||||
timesteps_proj_r = self.time_proj_r(timestep_r)
|
||||
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
|
||||
timesteps_emb = timesteps_emb + timesteps_emb_r
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_width_ratio: str = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
|
||||
|
||||
self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
gate_msa, gate_mlp = self.norm_out(temb)
|
||||
hidden_states = hidden_states + attn_output * gate_msa
|
||||
|
||||
ff_output = self.ff(self.norm2(hidden_states))
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoIndividualTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.refiner_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoIndividualTokenRefinerBlock(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
self_attn_mask = None
|
||||
if attention_mask is not None:
|
||||
# YiYi TODO convert 1D mask to 4d Bx1xLxL
|
||||
batch_size = attention_mask.shape[0]
|
||||
seq_len = attention_mask.shape[1]
|
||||
attention_mask = attention_mask.to(hidden_states.device).bool()
|
||||
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
||||
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||
|
||||
for block in self.refiner_blocks:
|
||||
hidden_states = block(hidden_states, temb, self_attn_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=hidden_size, pooled_projection_dim=in_channels
|
||||
)
|
||||
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.token_refiner = HunyuanVideoIndividualTokenRefiner(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_layers=num_layers,
|
||||
mlp_width_ratio=mlp_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is None:
|
||||
pooled_projections = hidden_states.mean(dim=1)
|
||||
else:
|
||||
original_dtype = hidden_states.dtype
|
||||
mask_float = attention_mask.float().unsqueeze(-1)
|
||||
pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
pooled_projections = pooled_projections.to(original_dtype)
|
||||
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
print(f" temb(time_text_embed).shape: {temb.shape}, {temb[0,:10]}")
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,:3,:3]}")
|
||||
print(f" temb: {temb.shape}, {temb[0,:10]}")
|
||||
print(f" attention_mask: {attention_mask.shape}, {attention_mask[0,:3]}, {attention_mask.abs().sum()}")
|
||||
print(f" -> token_refiner")
|
||||
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
|
||||
print(f" hidden_states(token_refiner) {hidden_states.shape}, {hidden_states[0,:3,:3]}")
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.rope_dim = rope_dim
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
|
||||
|
||||
axes_grids = []
|
||||
for i in range(3):
|
||||
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
|
||||
# original implementation creates it on CPU and then moves it to device. This results in numerical
|
||||
# differences in layerwise debugging outputs, but visually it is the same.
|
||||
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
||||
axes_grids.append(grid)
|
||||
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
|
||||
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
||||
|
||||
freqs = []
|
||||
for i in range(3):
|
||||
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
|
||||
freqs.append(freq)
|
||||
|
||||
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_hunyuanimage.HunyuanImageByT5TextProjection
|
||||
class HunyuanVideo15ByT5TextProjection(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_size: int, out_features: int):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(in_features)
|
||||
self.linear_1 = nn.Linear(in_features, hidden_size)
|
||||
self.linear_2 = nn.Linear(hidden_size, hidden_size)
|
||||
self.linear_3 = nn.Linear(hidden_size, out_features)
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm(encoder_hidden_states)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_3(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15ImageProjection(nn.Module):
|
||||
def __init__(self, in_channels: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.norm_in = nn.LayerNorm(in_channels)
|
||||
self.linear_1 = nn.Linear(in_channels, in_channels)
|
||||
self.act_fn = nn.GELU()
|
||||
self.linear_2 = nn.Linear(in_channels, hidden_size)
|
||||
self.norm_out = nn.LayerNorm(hidden_size)
|
||||
|
||||
def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm_in(image_embeds)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=hidden_size,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
print(f" norm_hidden_states(norm1).shape: {norm_hidden_states.shape}, {norm_hidden_states[0,:10,:3]}")
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
print(f" norm_encoder_hidden_states(norm1_context).shape: {norm_encoder_hidden_states.shape}, {norm_encoder_hidden_states[0,:10,:3]}")
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
print(f" attn_output.shape: {attn_output.shape}, {attn_output[0,:10,:3]}")
|
||||
print(f" context_attn_output.shape: {context_attn_output.shape}, {context_attn_output[0,:10,:3]}")
|
||||
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
print(f" hidden_states(ff): {hidden_states.shape}, {hidden_states[0,:10,:3]}")
|
||||
print(f" encoder_hidden_states(ff): {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of dual-stream blocks to use.
|
||||
num_refiner_layers (`int`, defaults to `2`):
|
||||
The number of layers of refiner blocks to use.
|
||||
mlp_ratio (`float`, defaults to `4.0`):
|
||||
The ratio of the hidden layer size to the input size in the feedforward network.
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the spatial patches to use in the patch embedding layer.
|
||||
patch_size_t (`int`, defaults to `1`):
|
||||
The size of the tmeporal patches to use in the patch embedding layer.
|
||||
qk_norm (`str`, defaults to `rms_norm`):
|
||||
The normalization to use for the query and key projections in the attention layers.
|
||||
guidance_embeds (`bool`, defaults to `True`):
|
||||
Whether to use guidance embeddings in the model.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
pooled_projection_dim (`int`, defaults to `768`):
|
||||
The dimension of the pooled projection of the text embeddings.
|
||||
rope_theta (`float`, defaults to `256.0`):
|
||||
The value of theta to use in the RoPE layer.
|
||||
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions of the axes to use in the RoPE layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
|
||||
_no_split_modules = [
|
||||
"HunyuanVideoTransformerBlock",
|
||||
"HunyuanVideoPatchEmbed",
|
||||
"HunyuanVideoTokenRefiner",
|
||||
]
|
||||
_repeated_blocks = [
|
||||
"HunyuanVideoTransformerBlock",
|
||||
"HunyuanVideoPatchEmbed",
|
||||
"HunyuanVideoTokenRefiner",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
num_attention_heads: int = 24,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 20,
|
||||
num_refiner_layers: int = 2,
|
||||
mlp_ratio: float = 4.0,
|
||||
patch_size: int = 1,
|
||||
patch_size_t: int = 1,
|
||||
qk_norm: str = "rms_norm",
|
||||
text_embed_dim: int = 3584,
|
||||
text_embed_2_dim: int = 1472,
|
||||
image_embed_dim: int = 1152,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||
use_meanflow: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Latent and condition embedders
|
||||
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
|
||||
self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim)
|
||||
|
||||
self.context_embedder = HunyuanVideoTokenRefiner(
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
|
||||
|
||||
self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow)
|
||||
|
||||
self.cond_type_embed = nn.Embedding(3, inner_dim)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
||||
|
||||
# 3. Dual stream transformer blocks
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask_2: Optional[torch.Tensor] = None,
|
||||
image_embeds: Optional[torch.Tensor] = None,
|
||||
timestep_r: Optional[torch.LongTensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
# 1. RoPE
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
temb = self.time_embed(timestep, timestep_r=timestep_r)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# qwen text embedding
|
||||
print(f" encoder_hidden_states(qwen).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
|
||||
print(f" timestep: {timestep}, {timestep[:10]}")
|
||||
print(f" encoder_attention_mask: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.abs().sum()}")
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
print(f" encoder_hidden_states(token_refiner).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
|
||||
|
||||
encoder_hidden_states_cond_emb = self.cond_type_embed(
|
||||
torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long)
|
||||
)
|
||||
encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb
|
||||
print(f" encoder_hidden_states(+ cond_emb).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
|
||||
|
||||
# byt5 text embedding
|
||||
encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
|
||||
print(f" encoder_hidden_states_2(byt5).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}")
|
||||
|
||||
encoder_hidden_states_2_cond_emb = self.cond_type_embed(
|
||||
torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long)
|
||||
)
|
||||
encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb
|
||||
print(f" encoder_hidden_states_2(+ cond_emb).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}")
|
||||
|
||||
# image embed
|
||||
encoder_hidden_states_3 = self.image_embedder(image_embeds)
|
||||
print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
|
||||
is_t2v = torch.all(image_embeds == 0)
|
||||
if is_t2v:
|
||||
encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0
|
||||
encoder_attention_mask_3 = torch.zeros(
|
||||
(batch_size, encoder_hidden_states_3.shape[1]),
|
||||
dtype=encoder_attention_mask.dtype,
|
||||
device=encoder_attention_mask.device,
|
||||
)
|
||||
print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
|
||||
print(f" encoder_attention_mask_3: {encoder_attention_mask_3.shape}, {encoder_attention_mask_3[0,:10]}, {encoder_attention_mask_3.abs().sum()}")
|
||||
else:
|
||||
encoder_attention_mask_3 = torch.ones(
|
||||
(batch_size, encoder_hidden_states_3.shape[1]),
|
||||
dtype=encoder_attention_mask.dtype,
|
||||
device=encoder_attention_mask.device,
|
||||
)
|
||||
encoder_hidden_states_3_cond_emb = self.cond_type_embed(
|
||||
2 * torch.ones_like(
|
||||
encoder_hidden_states_3[:, :, 0],
|
||||
dtype=torch.long,
|
||||
)
|
||||
)
|
||||
encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb
|
||||
|
||||
print(f" encoder_hidden_states_3(+ cond_emb).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}")
|
||||
|
||||
|
||||
# reorder and combine text tokens: combine valid tokens first, then padding
|
||||
encoder_attention_mask = encoder_attention_mask.bool()
|
||||
encoder_attention_mask_2 = encoder_attention_mask_2.bool()
|
||||
encoder_attention_mask_3 = encoder_attention_mask_3.bool()
|
||||
new_encoder_hidden_states = []
|
||||
new_encoder_attention_mask = []
|
||||
|
||||
for text, text_mask, text_2, text_mask_2, image, image_mask in zip(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
encoder_hidden_states_2,
|
||||
encoder_attention_mask_2,
|
||||
encoder_hidden_states_3,
|
||||
encoder_attention_mask_3,
|
||||
):
|
||||
# Concatenate: [valid_image, valid_byt5, valid_mllm, invalid_image, invalid_byt5, invalid_mllm]
|
||||
new_encoder_hidden_states.append(
|
||||
torch.cat(
|
||||
[
|
||||
image[image_mask], # valid image
|
||||
text_2[text_mask_2], # valid byt5
|
||||
text[text_mask], # valid mllm
|
||||
image[~image_mask], # invalid image
|
||||
torch.zeros_like(text_2[~text_mask_2]), # invalid byt5 (zeroed)
|
||||
torch.zeros_like(text[~text_mask]), # invalid mllm (zeroed)
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply same reordering to attention masks
|
||||
new_encoder_attention_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
image_mask[image_mask],
|
||||
text_mask_2[text_mask_2],
|
||||
text_mask[text_mask],
|
||||
image_mask[~image_mask],
|
||||
text_mask_2[~text_mask_2],
|
||||
text_mask[~text_mask],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
|
||||
encoder_hidden_states = torch.stack(new_encoder_hidden_states)
|
||||
encoder_attention_mask = torch.stack(new_encoder_attention_mask)
|
||||
|
||||
print(f" hidden_states.shape: {hidden_states.shape}, {hidden_states[0,:3,:3]}")
|
||||
print(f" encoder_hidden_states.shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}")
|
||||
print(f" encoder_attention_mask.shape: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.dtype}, {encoder_attention_mask.sum()}")
|
||||
print(f" image_rotary_emb: {image_rotary_emb[0].shape}, {image_rotary_emb[1].shape}, {image_rotary_emb[0][:3,:10]}, {image_rotary_emb[1][:3,:10]}")
|
||||
print(f" temb.shape: {temb.shape}, {temb[0,:10]}")
|
||||
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
encoder_attention_mask,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
encoder_attention_mask,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p_h, p_w
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
Reference in New Issue
Block a user