From 8b7ea8110bfdfcea4097596ca34ab849340a4d8a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Nov 2025 22:13:15 +0100 Subject: [PATCH] add --- .../convert_hunyuan_video1_5_to_diffusers.py | 382 +++++++ src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoder_kl_hunyuanvideo15.py | 968 ++++++++++++++++++ src/diffusers/models/transformers/__init__.py | 1 + .../transformer_hunyuan_video15.py | 937 +++++++++++++++++ 7 files changed, 2297 insertions(+) create mode 100644 scripts/convert_hunyuan_video1_5_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py create mode 100644 src/diffusers/models/transformers/transformer_hunyuan_video15.py diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py new file mode 100644 index 0000000000..59df02d190 --- /dev/null +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -0,0 +1,382 @@ +""" +python scripts/convert_hunyuan_video1_5_to_diffusers.py \ + --original_state_dict_folder /raid/yiyi/new-model-vid \ + --output_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \ + --transformer_type 480p_i2v \ + --dtype fp32 +""" + +import argparse +from typing import Any, Dict + +import torch +from accelerate import init_empty_weights +from safetensors.torch import load_file +from huggingface_hub import snapshot_download + +import pathlib +from diffusers import HunyuanVideo15Transformer3DModel + +TRANSFORMER_CONFIGS = { + "480p_i2v": { + "in_channels": 65, + "out_channels": 32, + "num_attention_heads": 16, + "attention_head_dim": 128, + "num_layers": 54, + "num_refiner_layers": 2, + "mlp_ratio": 4.0, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm", + "text_embed_dim": 3584, + "text_embed_2_dim": 1472, + "image_embed_dim": 1152, + "rope_theta": 256.0, + "rope_axes_dim": (16, 56, 56), + "use_meanflow": False, + }, +} + +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_hyvideo15_transformer_to_diffusers(original_state_dict): + """ + Convert HunyuanVideo 1.5 original checkpoint to Diffusers format. + """ + converted_state_dict = {} + + # 1. time_embed.timestep_embedder <- time_in + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop( + "time_in.mlp.0.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( + "time_in.mlp.0.bias" + ) + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop( + "time_in.mlp.2.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( + "time_in.mlp.2.bias" + ) + + # 2. 
context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.0.weight") + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.0.bias") + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.2.weight") + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.2.bias") + ) + + # 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = ( + original_state_dict.pop("txt_in.c_embedder.linear_1.weight") + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = ( + original_state_dict.pop("txt_in.c_embedder.linear_1.bias") + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = ( + original_state_dict.pop("txt_in.c_embedder.linear_2.weight") + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = ( + original_state_dict.pop("txt_in.c_embedder.linear_2.bias") + ) + + # 4. context_embedder.proj_in <- txt_in.input_embedder + converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop( + "txt_in.input_embedder.weight" + ) + converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias") + + # 5. context_embedder.token_refiner <- txt_in.individual_token_refiner + num_refiner_blocks = 2 + for i in range(num_refiner_blocks): + block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}." + orig_prefix = f"txt_in.individual_token_refiner.blocks.{i}." 
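+        # Each refiner block is remapped below: the fused self_attn_qkv projection is split
+        # into separate to_q/to_k/to_v tensors, mlp.fc1/fc2 map to ff.net.0.proj/ff.net.2,
+        # and adaLN_modulation.1 maps to norm_out.linear.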
+ + # norm1 + converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(f"{orig_prefix}norm1.weight") + converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(f"{orig_prefix}norm1.bias") + + # Split self_attn_qkv into to_q, to_k, to_v + qkv_weight = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.weight") + qkv_bias = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.bias") + q, k, v = torch.chunk(qkv_weight, 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0) + + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias + + # self_attn_proj -> attn.to_out.0 + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"{orig_prefix}self_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"{orig_prefix}self_attn_proj.bias" + ) + + # norm2 + converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(f"{orig_prefix}norm2.weight") + converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(f"{orig_prefix}norm2.bias") + + # mlp -> ff + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(f"{orig_prefix}mlp.fc2.bias") + + # adaLN_modulation -> norm_out + converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}adaLN_modulation.1.weight" + ) + converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}adaLN_modulation.1.bias" + ) + + # 6. context_embedder_2 <- byt5_in + converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight") + converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias") + converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight") + converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias") + converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight") + converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias") + converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight") + converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias") + + # 7. 
image_embedder <- vision_in + converted_state_dict["image_embedder.norm_in.weight"] = original_state_dict.pop("vision_in.proj.0.weight") + converted_state_dict["image_embedder.norm_in.bias"] = original_state_dict.pop("vision_in.proj.0.bias") + converted_state_dict["image_embedder.linear_1.weight"] = original_state_dict.pop("vision_in.proj.1.weight") + converted_state_dict["image_embedder.linear_1.bias"] = original_state_dict.pop("vision_in.proj.1.bias") + converted_state_dict["image_embedder.linear_2.weight"] = original_state_dict.pop("vision_in.proj.3.weight") + converted_state_dict["image_embedder.linear_2.bias"] = original_state_dict.pop("vision_in.proj.3.bias") + converted_state_dict["image_embedder.norm_out.weight"] = original_state_dict.pop("vision_in.proj.4.weight") + converted_state_dict["image_embedder.norm_out.bias"] = original_state_dict.pop("vision_in.proj.4.bias") + + # 8. x_embedder <- img_in + converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight") + converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias") + + # 9. cond_type_embed <- cond_type_embedding + converted_state_dict["cond_type_embed.weight"] = original_state_dict.pop("cond_type_embedding.weight") + + # 10. transformer_blocks <- double_blocks + num_layers = 54 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + orig_prefix = f"double_blocks.{i}." + + # norm1 (img_mod) + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mod.linear.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mod.linear.bias" + ) + + # norm1_context (txt_mod) + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mod.linear.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mod.linear.bias" + ) + + # img attention (to_q, to_k, to_v) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_v.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_v.bias" + ) + + # img attention qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k_norm.weight" + ) + + # img attention output projection + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_proj.bias" + ) + + # txt attention (add_q_proj, add_k_proj, add_v_proj) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop( + 
f"{orig_prefix}txt_attn_q.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_q.bias" + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k.bias" + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_v.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_v.bias" + ) + + # txt attention qk norm + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_q_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k_norm.weight" + ) + + # txt attention output projection + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_proj.bias" + ) + + # norm2 and norm2_context (these don't have weights in the original, they're LayerNorm with elementwise_affine=False) + # So we skip them + + # img_mlp -> ff + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc2.bias" + ) + + # txt_mlp -> ff_context + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc2.bias" + ) + + # 11. 
norm_out and proj_out <- final_layer + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop( + "final_layer.adaLN_modulation.1.weight" + )) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.adaLN_modulation.1.bias")) + converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + + return converted_state_dict + + +def load_sharded_safetensors(dir: pathlib.Path): + file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) + state_dict = {} + for path in file_paths: + state_dict.update(load_file(path)) + return state_dict + + +def load_original_state_dict(args): + if args.original_state_dict_repo_id is not None: + model_dir = snapshot_download( + args.original_state_dict_repo_id, + repo_type="model", + allow_patterns="transformer/" + args.transformer_type + "/*" + ) + elif args.original_state_dict_folder is not None: + model_dir = pathlib.Path(args.original_state_dict_folder) + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`") + model_dir = pathlib.Path(model_dir) + model_dir = model_dir / "transformer" / args.transformer_type + return load_sharded_safetensors(model_dir) + +def convert_transformer(args): + original_state_dict = load_original_state_dict(args) + + config = TRANSFORMER_CONFIGS[args.transformer_type] + with init_empty_weights(): + transformer = HunyuanVideo15Transformer3DModel(**config) + state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict) + transformer.load_state_dict(state_dict, strict=True, assign=True) + + return transformer + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model" + ) + parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Folder name of the original state dict") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") + parser.add_argument( + "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()) + ) + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + transformer = convert_transformer(args) + transformer = transformer.to(dtype=dtype) + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index cd7a2cb581..eb9929cf2c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -189,6 +189,7 @@ else: "AutoencoderKLHunyuanImage", "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", + "AutoencoderKLHunyuanVideo15", "AutoencoderKLLTXVideo", "AutoencoderKLMagvit", "AutoencoderKLMochi", @@ -223,6 +224,7 @@ else: "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", "HunyuanImageTransformer2DModel", + "HunyuanVideo15Transformer3DModel", "HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", @@ -903,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: 
AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, + AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -939,6 +942,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: HunyuanImageTransformer2DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, + HunyuanVideo15Transformer3DModel, I2VGenXLUNet, Kandinsky3UNet, Kandinsky5Transformer3DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b42e981f71..aa80c93f2b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -38,6 +38,7 @@ if is_torch_available(): _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] + _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] @@ -83,6 +84,7 @@ if is_torch_available(): _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] + _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] @@ -143,6 +145,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, + AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -191,6 +194,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: DualTransformer2DModel, EasyAnimateTransformer3DModel, FluxTransformer2DModel, + HunyuanVideo15Transformer3DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index edfaabb070..6d39785a86 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -7,6 +7,7 @@ from .autoencoder_kl_cosmos import AutoencoderKLCosmos from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner +from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py new file mode 100644 index 0000000000..2f05172a97 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -0,0 +1,968 @@ +# Copyright 
2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanVideo15CausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanVideo15RMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. 
+ """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class HunyuanVideo15AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = HunyuanVideo15RMS_norm(in_channels, images=False) + + self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1) + + @staticmethod + def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None): + """Prepare a causal attention mask for 3D videos. + + Args: + n_frame (int): Number of frames (temporal length). + n_hw (int): Product of height and width. + dtype: Desired mask dtype. + device: Device for the mask. + batch_size (int, optional): If set, expands for batch. + + Returns: + torch.Tensor: Causal attention mask. + """ + seq_len = n_frame * n_hw + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // n_hw + mask[i, : (i_frame + 1) * n_hw] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + x = self.norm(x) + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + batch_size, channels, frames, height, width = query.shape + + query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + + attention_mask = self.prepare_causal_attention_mask(frames, height * width, query.dtype, query.device, batch_size=batch_size) + + x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + # batch_size, 1, frames * height * width, channels + + x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3) + x = self.proj_out(x) + + return x + identity + + +class HunyuanVideo15Upsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2 + self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels * factor, kernel_size=3) + + self.add_temporal_upsample = add_temporal_upsample + self.repeats = factor * out_channels // in_channels + + @staticmethod + def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w) + + Args: + tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w) + r1: temporal upsampling factor + r2: height upsampling factor + r3: width upsampling factor + """ + b, packed_c, f, h, w = tensor.shape + factor 
= r1 * r2 * r3 + c = packed_c // factor + + tensor = tensor.view(b, r1, r2, r3, c, f, h, w) + tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3) + return tensor.reshape(b, c, f * r1, h * r2, w * r3) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_upsample else 1 + h = self.conv(x) + if self.add_temporal_upsample: + h_first = h[:, :, :1, :, :] + h_first = self._dcae_upsample_rearrange(h_first, r1=1, r2=2, r3=2) + h_first = h_first[:, : h_first.shape[1] // 2] + h_next = h[:, :, 1:, :, :] + h_next = self._dcae_upsample_rearrange(h_next, r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2) + x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1) + + x_next = x[:, :, 1:, :, :] + x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2) + x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = torch.cat([x_first, x_next], dim=2) + + else: + h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2) + return h + shortcut + + +class HunyuanVideo15Downsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2 + assert out_channels % factor == 0 + # self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3) + + self.add_temporal_downsample = add_temporal_downsample + self.group_size = factor * in_channels // out_channels + + @staticmethod + def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w) + + This packs spatial/temporal dimensions into channels (opposite of upsample) + """ + b, c, packed_f, packed_h, packed_w = tensor.shape + f, h, w = packed_f // r1, packed_h // r2, packed_w // r3 + + tensor = tensor.view(b, c, f, r1, h, r2, w, r3) + tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6) + return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_downsample else 1 + h = self.conv(x) + if self.add_temporal_downsample: + h_first = h[:, :, :1, :, :] + h_first = self._dcae_downsample_rearrange(h_first, r1=1, r2=2, r3=2) + h_first = torch.cat([h_first, h_first], dim=1) + h_next = h[:, :, 1:, :, :] + h_next = self._dcae_downsample_rearrange(h_next, r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = self._dcae_downsample_rearrange(x_first, r1=1, r2=2, r3=2) + B, C, T, H, W = x_first.shape + x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2) + x_next = x[:, :, 1:, :, :] + x_next = self._dcae_downsample_rearrange(x_next, r1=r1, r2=2, r3=2) + B, C, T, H, W = x_next.shape + x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + shortcut = torch.cat([x_first, x_next], dim=2) + else: + h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2) + B, C, T, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + + return h + shortcut + + +class 
HunyuanVideo15ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = HunyuanVideo15RMS_norm(in_channels, images=False) + self.conv1 = HunyuanVideo15CausalConv3d(in_channels, out_channels, kernel_size=3) + + self.norm2 = HunyuanVideo15RMS_norm(out_channels, images=False) + self.conv2 = HunyuanVideo15CausalConv3d(out_channels, out_channels, kernel_size=3) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual + + +class HunyuanVideo15MidBlock(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int = 1, + add_attention: bool = True, + ) -> None: + super().__init__() + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append(HunyuanVideo15AttnBlock(in_channels)) + else: + attentions.append(None) + + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideo15DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + downsample_out_channels: Optional[int] = None, + add_temporal_downsample: int = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if downsample_out_channels is not None: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideo15Downsample( + out_channels, + out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideo15UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: 
int = 1,
+        upsample_out_channels: Optional[int] = None,
+        add_temporal_upsample: bool = True,
+    ) -> None:
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(
+                HunyuanVideo15ResnetBlock(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if upsample_out_channels is not None:
+            self.upsamplers = nn.ModuleList(
+                [
+                    HunyuanVideo15Upsample(
+                        out_channels,
+                        out_channels=upsample_out_channels,
+                        add_temporal_upsample=add_temporal_upsample,
+                    )
+                ]
+            )
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for resnet in self.resnets:
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
+
+        else:
+            for resnet in self.resnets:
+                hidden_states = resnet(hidden_states)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
+class HunyuanVideo15Encoder3D(nn.Module):
+    r"""
+    3D causal VAE encoder used by HunyuanVideo 1.5.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 64,
+        block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
+        layers_per_block: int = 2,
+        temporal_compression_ratio: int = 4,
+        spatial_compression_ratio: int = 16,
+        downsample_match_channel: bool = True,
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.group_size = block_out_channels[-1] // self.out_channels
+
+        self.conv_in = HunyuanVideo15CausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
+        self.mid_block = None
+        self.down_blocks = nn.ModuleList([])
+
+        input_channel = block_out_channels[0]
+        for i in range(len(block_out_channels)):
+            add_spatial_downsample = i < np.log2(spatial_compression_ratio)
+            output_channel = block_out_channels[i]
+            if not add_spatial_downsample:
+                down_block = HunyuanVideo15DownBlock3D(
+                    num_layers=layers_per_block,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    downsample_out_channels=None,
+                    add_temporal_downsample=False,
+                )
+                input_channel = output_channel
+            else:
+                add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
+                downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
+                down_block = HunyuanVideo15DownBlock3D(
+                    num_layers=layers_per_block,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    downsample_out_channels=downsample_out_channels,
+                    add_temporal_downsample=add_temporal_downsample,
+                )
+                input_channel = downsample_out_channels
+
+            self.down_blocks.append(down_block)
+
+        self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[-1])
+
+        self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
+        self.conv_act = nn.SiLU()
+        self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.conv_in(hidden_states)
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for down_block in self.down_blocks:
+                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
+
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, 
hidden_states)
+        else:
+            for down_block in self.down_blocks:
+                hidden_states = down_block(hidden_states)
+
+            hidden_states = self.mid_block(hidden_states)
+
+        # short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
+        batch_size, _, frame, height, width = hidden_states.shape
+        short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
+
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        hidden_states += short_cut
+
+        return hidden_states
+
+
+class HunyuanVideo15Decoder3D(nn.Module):
+    r"""
+    Causal decoder for 3D video-like data used by HunyuanVideo 1.5.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 32,
+        out_channels: int = 3,
+        block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
+        layers_per_block: int = 2,
+        spatial_compression_ratio: int = 16,
+        temporal_compression_ratio: int = 4,
+        upsample_match_channel: bool = True,
+    ):
+        super().__init__()
+        self.layers_per_block = layers_per_block
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.repeat = block_out_channels[0] // self.in_channels
+
+        self.conv_in = HunyuanVideo15CausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
+        self.up_blocks = nn.ModuleList([])
+
+        # mid
+        self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[0])
+
+        # up
+        input_channel = block_out_channels[0]
+        for i in range(len(block_out_channels)):
+            output_channel = block_out_channels[i]
+
+            add_spatial_upsample = i < np.log2(spatial_compression_ratio)
+            add_temporal_upsample = i < np.log2(temporal_compression_ratio)
+            if add_spatial_upsample or add_temporal_upsample:
+                upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
+                up_block = HunyuanVideo15UpBlock3D(
+                    num_layers=self.layers_per_block + 1,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    upsample_out_channels=upsample_out_channels,
+                    add_temporal_upsample=add_temporal_upsample,
+                )
+                input_channel = upsample_out_channels
+            else:
+                up_block = HunyuanVideo15UpBlock3D(
+                    num_layers=self.layers_per_block + 1,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    upsample_out_channels=None,
+                    add_temporal_upsample=False,
+                )
+                input_channel = output_channel
+
+            self.up_blocks.append(up_block)
+
+        # out
+        self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
+        self.conv_act = nn.SiLU()
+        self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # Channel-duplication shortcut around conv_in, mirroring the channel-averaging
+        # shortcut applied at the end of the encoder.
+        hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+
+            for up_block in self.up_blocks:
+                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
+        else:
+            hidden_states = self.mid_block(hidden_states)
+
+            for up_block in self.up_blocks:
+                hidden_states = up_block(hidden_states)
+
+        # post-process
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+        return hidden_states
+
+
+class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
+    r"""
+    A VAE model with KL loss for 
encoding videos into latents and decoding latent representations into videos. Used for + HunyuanVideo-1.5. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 32, + block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + spatial_compression_ratio: int = 16, + temporal_compression_ratio: int = 4, + downsample_match_channel: bool = True, + upsample_match_channel: bool = True, + scaling_factor: float = 1.03682, + ) -> None: + super().__init__() + + self.encoder = HunyuanVideo15Encoder3D( + in_channels=in_channels, + out_channels=latent_channels * 2, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + downsample_match_channel=downsample_match_channel, + ) + + self.decoder = HunyuanVideo15Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + upsample_match_channel=upsample_match_channel, + ) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal tile height and width in latent space + self.tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio + self.tile_overlap_factor = 0.25 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_latent_min_height: Optional[int] = None, + tile_latent_min_width: Optional[int] = None, + tile_overlap_factor: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_latent_min_height (`int`, *optional*): + The minimum height required for a latent to be separated into tiles across the height dimension. 
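+            tile_overlap_factor (`float`, *optional*):
+                The fraction by which adjacent tiles overlap; the overlapping regions are blended together when
+                the tiles are stitched back into a single output.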
+ tile_latent_min_width (`int`, *optional*): + The minimum width required for a latent to be separated into tiles across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = tile_latent_min_height or self.tile_latent_min_height + self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width + self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z) + + dec = self.decoder(z) + + return dec + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
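+
+        Example (a minimal sketch; the checkpoint path below is a placeholder):
+
+        ```py
+        >>> import torch
+        >>> from diffusers import AutoencoderKLHunyuanVideo15
+
+        >>> # Placeholder path; point this at a converted HunyuanVideo 1.5 VAE checkpoint.
+        >>> vae = AutoencoderKLHunyuanVideo15.from_pretrained("path/to/hunyuanvideo15-vae", torch_dtype=torch.float32)
+        >>> vae.enable_tiling()
+        >>> latents = torch.randn(1, 32, 3, 32, 32)
+        >>> video = vae.decode(latents).sample
+        ```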
+ """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, _, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + blend_height = int(self.tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2 + blend_width = int(self.tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2 + row_limit_height = self.tile_latent_min_height - blend_height # 8 - 2 = 6 + row_limit_width = self.tile_latent_min_width - blend_width # 8 - 2 = 6 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + moments = torch.cat(result_rows, dim=-2) + + return moments + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + + _, _, _, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64 + blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64 + row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192 + row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=-2) + + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 826469237f..ee408acd23 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -44,3 +44,4 @@ if is_torch_available(): from .transformer_wan import WanTransformer3DModel from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel + from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py new file mode 100644 index 0000000000..8fb2ea451a --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -0,0 +1,937 @@ +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.loaders import FromOriginalModelMixin + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention, AttentionProcessor +from ..cache_utils import CacheMixin +from ..embeddings import ( + CombinedTimestepTextProjEmbeddings, + TimestepEmbedding, + Timesteps, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanVideo15AttnProcessor2_0: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HunyuanVideo15AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.add_q_proj is None and encoder_hidden_states is not None: + assert False # YiYi Notes: remove this condition if this code path is never used + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + else: + assert False + # YiYi Notes: remove this condition if this code path is never used + if attn.norm_k is not None: + key = attn.norm_k(key) + else: + assert False + # YiYi Notes: remove this condition if this code path is never used + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + if attn.add_q_proj is None and encoder_hidden_states is not None: + assert False # YiYi Notes: remove this condition if this code path is never used + query = torch.cat( + [ + apply_rotary_emb( + query[:, : -encoder_hidden_states.shape[1]], + image_rotary_emb, + sequence_dim=1, + ), + query[:, -encoder_hidden_states.shape[1] :], + ], + dim=1, + ) + key = torch.cat( + [ + apply_rotary_emb( + key[:, : -encoder_hidden_states.shape[1]], + image_rotary_emb, + sequence_dim=1, + ), + key[:, -encoder_hidden_states.shape[1] :], + ], + dim=1, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + # 4. 
Encoder condition QKV projection and normalization
+        if attn.add_q_proj is not None and encoder_hidden_states is not None:
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+            if attn.norm_added_q is not None:
+                encoder_query = attn.norm_added_q(encoder_query)
+            if attn.norm_added_k is not None:
+                encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([query, encoder_query], dim=1)
+            key = torch.cat([key, encoder_key], dim=1)
+            value = torch.cat([value, encoder_value], dim=1)
+        else:
+            assert False  # YiYi Notes: remove this condition if this code path is never used
+
+        # Pad the condition-token mask on the left so it also covers the latent tokens, then build a
+        # symmetric 2D self-attention mask so attention only happens between pairs of valid tokens.
+        batch_size, seq_len, heads, dim = query.shape
+        attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True)
+        attention_mask = attention_mask.bool()
+        self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+        self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+        attention_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+
+        # 5. Attention
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
+        )
+
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # 6. 
Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideo15TimeEmbedding(nn.Module): + r""" + Time embedding for HunyuanVideo 1.5. + + Supports standard timestep embedding and optional reference timestep embedding + for MeanFlow-based super-resolution models. + + Args: + embedding_dim (`int`): + The dimension of the output embedding. + use_meanflow (`bool`, defaults to `False`): + Whether to support reference timestep embeddings for temporal consistency. + Set to `True` for super-resolution models. 
+ """ + def __init__( + self, + embedding_dim: int, + use_meanflow: bool = False, + ): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_meanflow = use_meanflow + + self.time_proj_r = None + self.timestep_embedder_r = None + if use_meanflow: + self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + + def forward( + self, + timestep: torch.Tensor, + timestep_r: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype)) + + if timestep_r is not None: + timesteps_proj_r = self.time_proj_r(timestep_r) + timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype)) + timesteps_emb = timesteps_emb + timesteps_emb_r + + return timesteps_emb + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + # YiYi TODO convert 1D mask to 4d Bx1xLxL + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + 
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + print(f" temb(time_text_embed).shape: {temb.shape}, {temb[0,:10]}") + hidden_states = self.proj_in(hidden_states) + print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,:3,:3]}") + print(f" temb: {temb.shape}, {temb[0,:10]}") + print(f" attention_mask: {attention_mask.shape}, {attention_mask[0,:3]}, {attention_mask.abs().sum()}") + print(f" -> token_refiner") + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + print(f" hidden_states(token_refiner) {hidden_states.shape}, {hidden_states[0,:3,:3]}") + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(3): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. 
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + return freqs_cos, freqs_sin + + +# Copied from diffusers.models.transformers.transformer_hunyuanimage.HunyuanImageByT5TextProjection +class HunyuanVideo15ByT5TextProjection(nn.Module): + def __init__(self, in_features: int, hidden_size: int, out_features: int): + super().__init__() + self.norm = nn.LayerNorm(in_features) + self.linear_1 = nn.Linear(in_features, hidden_size) + self.linear_2 = nn.Linear(hidden_size, hidden_size) + self.linear_3 = nn.Linear(hidden_size, out_features) + self.act_fn = nn.GELU() + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(encoder_hidden_states) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_3(hidden_states) + return hidden_states + + +class HunyuanVideo15ImageProjection(nn.Module): + def __init__(self, in_channels: int, hidden_size: int): + super().__init__() + self.norm_in = nn.LayerNorm(in_channels) + self.linear_1 = nn.Linear(in_channels, in_channels) + self.act_fn = nn.GELU() + self.linear_2 = nn.Linear(in_channels, hidden_size) + self.norm_out = nn.LayerNorm(hidden_size) + + def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm_in(image_embeds) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.norm_out(hidden_states) + return hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. 
Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + print(f" norm_hidden_states(norm1).shape: {norm_hidden_states.shape}, {norm_hidden_states[0,:10,:3]}") + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + print(f" norm_encoder_hidden_states(norm1_context).shape: {norm_encoder_hidden_states.shape}, {norm_encoder_hidden_states[0,:10,:3]}") + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + print(f" attn_output.shape: {attn_output.shape}, {attn_output[0,:10,:3]}") + print(f" context_attn_output.shape: {context_attn_output.shape}, {context_attn_output[0,:10,:3]}") + + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + print(f" hidden_states(ff): {hidden_states.shape}, {hidden_states[0,:10,:3]}") + print(f" encoder_hidden_states(ff): {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + + return hidden_states, encoder_hidden_states + + +class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): + r""" + A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). + + Args: + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of dual-stream blocks to use. + num_refiner_layers (`int`, defaults to `2`): + The number of layers of refiner blocks to use. + mlp_ratio (`float`, defaults to `4.0`): + The ratio of the hidden layer size to the input size in the feedforward network. + patch_size (`int`, defaults to `2`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + qk_norm (`str`, defaults to `rms_norm`): + The normalization to use for the query and key projections in the attention layers. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings in the model. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. 
+ pooled_projection_dim (`int`, defaults to `768`): + The dimension of the pooled projection of the text embeddings. + rope_theta (`float`, defaults to `256.0`): + The value of theta to use in the RoPE layer. + rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions of the axes to use in the RoPE layer. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] + _no_split_modules = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] + _repeated_blocks = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 1, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + text_embed_dim: int = 3584, + text_embed_2_dim: int = 1472, + image_embed_dim: int = 1152, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int, ...] = (16, 56, 56), + use_meanflow: bool = False, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim) + + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim) + + self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow) + + self.cond_type_embed = nn.Embedding(3, inner_dim) + + # 2. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + encoder_attention_mask_2: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + timestep_r: Optional[torch.LongTensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. 
Conditional embeddings + temb = self.time_embed(timestep, timestep_r=timestep_r) + + hidden_states = self.x_embedder(hidden_states) + + # qwen text embedding + print(f" encoder_hidden_states(qwen).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + print(f" timestep: {timestep}, {timestep[:10]}") + print(f" encoder_attention_mask: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.abs().sum()}") + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + print(f" encoder_hidden_states(token_refiner).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + + encoder_hidden_states_cond_emb = self.cond_type_embed( + torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long) + ) + encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb + print(f" encoder_hidden_states(+ cond_emb).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + + # byt5 text embedding + encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2) + print(f" encoder_hidden_states_2(byt5).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}") + + encoder_hidden_states_2_cond_emb = self.cond_type_embed( + torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long) + ) + encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb + print(f" encoder_hidden_states_2(+ cond_emb).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}") + + # image embed + encoder_hidden_states_3 = self.image_embedder(image_embeds) + print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") + is_t2v = torch.all(image_embeds == 0) + if is_t2v: + encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0 + encoder_attention_mask_3 = torch.zeros( + (batch_size, encoder_hidden_states_3.shape[1]), + dtype=encoder_attention_mask.dtype, + device=encoder_attention_mask.device, + ) + print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") + print(f" encoder_attention_mask_3: {encoder_attention_mask_3.shape}, {encoder_attention_mask_3[0,:10]}, {encoder_attention_mask_3.abs().sum()}") + else: + encoder_attention_mask_3 = torch.ones( + (batch_size, encoder_hidden_states_3.shape[1]), + dtype=encoder_attention_mask.dtype, + device=encoder_attention_mask.device, + ) + encoder_hidden_states_3_cond_emb = self.cond_type_embed( + 2 * torch.ones_like( + encoder_hidden_states_3[:, :, 0], + dtype=torch.long, + ) + ) + encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb + + print(f" encoder_hidden_states_3(+ cond_emb).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") + + + # reorder and combine text tokens: combine valid tokens first, then padding + encoder_attention_mask = encoder_attention_mask.bool() + encoder_attention_mask_2 = encoder_attention_mask_2.bool() + encoder_attention_mask_3 = encoder_attention_mask_3.bool() + new_encoder_hidden_states = [] + new_encoder_attention_mask = [] + + for text, text_mask, text_2, text_mask_2, image, image_mask in zip( + encoder_hidden_states, + encoder_attention_mask, + encoder_hidden_states_2, + encoder_attention_mask_2, + encoder_hidden_states_3, + encoder_attention_mask_3, + ): + # Concatenate: [valid_image, valid_byt5, valid_mllm, invalid_image, invalid_byt5, invalid_mllm] + 
new_encoder_hidden_states.append( + torch.cat( + [ + image[image_mask], # valid image + text_2[text_mask_2], # valid byt5 + text[text_mask], # valid mllm + image[~image_mask], # invalid image + torch.zeros_like(text_2[~text_mask_2]), # invalid byt5 (zeroed) + torch.zeros_like(text[~text_mask]), # invalid mllm (zeroed) + ], + dim=0, + ) + ) + + # Apply same reordering to attention masks + new_encoder_attention_mask.append( + torch.cat( + [ + image_mask[image_mask], + text_mask_2[text_mask_2], + text_mask[text_mask], + image_mask[~image_mask], + text_mask_2[~text_mask_2], + text_mask[~text_mask], + ], + dim=0, + ) + ) + + encoder_hidden_states = torch.stack(new_encoder_hidden_states) + encoder_attention_mask = torch.stack(new_encoder_attention_mask) + + print(f" hidden_states.shape: {hidden_states.shape}, {hidden_states[0,:3,:3]}") + print(f" encoder_hidden_states.shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + print(f" encoder_attention_mask.shape: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.dtype}, {encoder_attention_mask.sum()}") + print(f" image_rotary_emb: {image_rotary_emb[0].shape}, {image_rotary_emb[1].shape}, {image_rotary_emb[0][:3,:10]}, {image_rotary_emb[1][:3,:10]}") + print(f" temb.shape: {temb.shape}, {temb[0,:10]}") + + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + ) + + # 5. Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p_h, p_w + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states)
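
A quick way to sanity-check the unpatchify step at the end of `forward` (the reshape -> permute -> flatten
chain above) is to run it standalone: it folds patch tokens of shape (B, T*H*W, C * p_t * p_h * p_w) back
into a dense (B, C, F, H, W) latent. The sketch below uses made-up sizes chosen only for illustration; it is
not tied to any of the checkpoint configs.

    import torch

    batch_size, out_channels = 1, 4
    p_t, p_h, p_w = 1, 2, 2
    post_f, post_h, post_w = 3, 4, 4  # number of patches along frame / height / width

    # tokens laid out the way proj_out produces them: (B, seq, C * p_t * p_h * p_w)
    tokens = torch.randn(batch_size, post_f * post_h * post_w, out_channels * p_t * p_h * p_w)

    # same reshape -> permute -> flatten sequence as in the model forward
    x = tokens.reshape(batch_size, post_f, post_h, post_w, -1, p_t, p_h, p_w)
    x = x.permute(0, 4, 1, 5, 2, 6, 3, 7)
    x = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)

    # back to a dense latent video: (B, C, F, H, W)
    assert x.shape == (batch_size, out_channels, post_f * p_t, post_h * p_h, post_w * p_w)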