mirror of https://github.com/huggingface/diffusers.git
update conversion script
@@ -1,16 +1,19 @@
 """
 python scripts/convert_hunyuan_video1_5_to_diffusers.py \
-    --original_state_dict_folder /raid/yiyi/new-model-vid \
-    --output_transformer_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \
+    --original_state_dict_repo_id tencent/HunyuanVideo-1.5\
+    --output_path /fsx/yiyi/hy15/480p_i2v\
     --transformer_type 480p_i2v \
     --dtype fp32
 """

 """
 python scripts/convert_hunyuan_video1_5_to_diffusers.py \
-    --original_state_dict_folder /raid/yiyi/new-model-vid \
-    --output_vae_path /raid/yiyi/hunyuanvideo15-vae \
-    --dtype fp32
+    --original_state_dict_repo_id tencent/HunyuanVideo-1.5\
+    --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \
+    --dtype bf16 \
+    --save_pipeline \
+    --byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\
+    --transformer_type 480p_i2v
 """

 import argparse
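Not part of the diff: with the first example command above (no --save_pipeline), only the transformer is converted and written to --output_path, so it can be reloaded on its own through the standard diffusers interface. A minimal sketch, reusing the example path and fp32 dtype from that command:

import torch

from diffusers import HunyuanVideo15Transformer3DModel

# path and dtype follow the 480p_i2v example command above
transformer = HunyuanVideo15Transformer3DModel.from_pretrained(
    "/fsx/yiyi/hy15/480p_i2v", torch_dtype=torch.float32
)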
@@ -22,11 +25,12 @@ from safetensors.torch import load_file
 from huggingface_hub import snapshot_download, hf_hub_download

 import pathlib
-from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15
+from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline
+from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer

 import json
 import argparse
 import os

 TRANSFORMER_CONFIGS = {
     "480p_i2v": {
@@ -49,6 +53,20 @@ TRANSFORMER_CONFIGS = {
     },
 }

+SCHEDULER_CONFIGS = {
+    "480p_i2v": {
+        "shift": 5.0,
+    },
+}
+
+GUIDANCE_CONFIGS = {
+    "480p_i2v": {
+        "guidance_scale": 6.0,
+        "embedded_guidance_scale": None,
+    },
+
+}
+
 def swap_scale_shift(weight):
     shift, scale = weight.chunk(2, dim=0)
     new_weight = torch.cat([scale, shift], dim=0)
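The SCHEDULER_CONFIGS and GUIDANCE_CONFIGS tables added above are consumed in the __main__ block at the bottom of this diff when --save_pipeline is set. A minimal sketch of that lookup (relying on the imports and dicts defined in the script), using the only variant present in this diff:

# build the scheduler and guider from the 480p_i2v entries, as the script does
transformer_type = "480p_i2v"
scheduler = FlowMatchEulerDiscreteScheduler(shift=SCHEDULER_CONFIGS[transformer_type]["shift"])
guider = ClassifierFreeGuidance(guidance_scale=GUIDANCE_CONFIGS[transformer_type]["guidance_scale"])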
@@ -571,18 +589,16 @@ def convert_vae(args):
     vae.load_state_dict(state_dict, strict=True, assign=True)
     return vae

-def save_text_encoder(output_path):
+def load_mllm():
     print(f" loading from Qwen/Qwen2.5-VL-7B-Instruct")
     text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", low_cpu_mem_usage=True)
     if hasattr(text_encoder, 'language_model'):
         text_encoder = text_encoder.language_model


-    text_encoder.save_pretrained(output_path + "/text_encoder")

     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
-    tokenizer.save_pretrained(output_path + "/tokenizer")
+    return text_encoder, tokenizer


 #copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
 def add_special_token(
     tokenizer,
     text_encoder,
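With this change the script no longer writes the Qwen encoder to disk here; load_mllm() simply returns the language_model sub-module of Qwen2.5-VL (the text-only decoder) together with its tokenizer, and the pipeline assembly at the bottom reuses both. A hedged usage sketch, not taken from this script; the prompt and the feature extraction below are illustrative assumptions:

text_encoder, tokenizer = load_mllm()
inputs = tokenizer("a cat walking on grass", return_tensors="pt")
# assumption: the returned decoder exposes standard hidden states usable as prompt features
features = text_encoder(**inputs).last_hidden_state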
@@ -625,42 +641,36 @@ def add_special_token(
     text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)


-def save_text_encoder_2(
-    byt5_base_path,
-    byt5_checkpoint_path,
-    color_ann_path,
-    font_ann_path,
-    output_path,
-    multilingual=True
-):


+def load_byt5(args):
     """
     Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format.

     Args:
         byt5_base_path: Path to base byt5-small model (e.g., "google/byt5-small")
         byt5_checkpoint_path: Path to Glyph-SDXL-v2 checkpoint (byt5_model.pt)
         color_ann_path: Path to color_idx.json
         font_ann_path: Path to multilingual_10-lang_idx.json
         output_path: Where to save the converted model
         multilingual: Whether to use multilingual font tokens
     """



     # 1. Load base tokenizer and encoder
-    tokenizer = AutoTokenizer.from_pretrained(byt5_base_path)
+    tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

     # Load as T5EncoderModel
-    encoder = T5EncoderModel.from_pretrained(byt5_base_path)
+    encoder = T5EncoderModel.from_pretrained("google/byt5-small")

+    byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt")
+    color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json")
+    font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json")

     # 2. Add special tokens
     add_special_token(
-        tokenizer,
-        encoder,
+        tokenizer=tokenizer,
+        text_encoder=encoder,
         add_color=True,
         add_font=True,
         color_ann_path=color_ann_path,
         font_ann_path=font_ann_path,
-        multilingual=multilingual
+        multilingual=True,
     )


     # 3. Load Glyph-SDXL-v2 checkpoint
     print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}")
     checkpoint = torch.load(byt5_checkpoint_path, map_location='cpu')
@@ -694,11 +704,7 @@ def save_text_encoder_2(
         raise ValueError(f"Missing keys: {missing_keys}")


-    # Save encoder
-    encoder.save_pretrained(output_path + "/text_encoder_2")
-
-    # Save tokenizer
-    tokenizer.save_pretrained(output_path + "/tokenizer_2")
+    return encoder, tokenizer


 def get_args():
@@ -707,12 +713,26 @@ def get_args():
         "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
     )
     parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict")
-    parser.add_argument("--output_vae_path", type=str, default=None, help="Path where converted VAE should be saved")
-    parser.add_argument("--output_transformer_path", type=str, default=None, help="Path where converted transformer should be saved")
+    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
     parser.add_argument(
         "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())
     )
+    parser.add_argument(
+        "--byt5_path",
+        type=str,
+        default=None,
+        help=(
+            "path to the downloaded byt5 checkpoint & assets. "
+            "Note: They use Glyph-SDXL-v2 as byt5 encoder. You can download from modelscope like: "
+            "`modelscope download --model AI-ModelScope/Glyph-SDXL-v2 --local_dir ./ckpts/text_encoder/Glyph-SDXL-v2` "
+            "or manually download following the instructions on "
+            "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/checkpoints-download.md. "
+            "The path should point to the Glyph-SDXL-v2 folder which should contain an `assets` folder and a `checkpoints` folder, "
+            "like: Glyph-SDXL-v2/assets/... and Glyph-SDXL-v2/checkpoints/byt5_model.pt"
+        ),
+    )
+    parser.add_argument("--save_pipeline", action="store_true")
     return parser.parse_args()


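For reference, a sketch of the layout that --byt5_path is expected to point at and the paths load_byt5() derives from it, based on the help text above and the os.path.join calls earlier in the diff; the /fsx/... prefix is only the example value from the module docstring:

import os

# Glyph-SDXL-v2/
# ├── assets/
# │   ├── color_idx.json
# │   └── multilingual_10-lang_idx.json
# └── checkpoints/
#     └── byt5_model.pt
byt5_path = "/fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2"
byt5_checkpoint_path = os.path.join(byt5_path, "checkpoints/byt5_model.pt")
color_ann_path = os.path.join(byt5_path, "assets/color_idx.json")
font_ann_path = os.path.join(byt5_path, "assets/multilingual_10-lang_idx.json")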
@@ -726,16 +746,44 @@ DTYPE_MAPPING = {
 if __name__ == "__main__":
     args = get_args()

+    if args.save_pipeline and args.byt5_path is None:
+        raise ValueError("Please provide --byt5_path when saving pipeline")
+
     transformer = None
     dtype = DTYPE_MAPPING[args.dtype]

-    if args.output_transformer_path is not None:
-        transformer = convert_transformer(args)
-        transformer = transformer.to(dtype=dtype)
-        transformer.save_pretrained(args.output_transformer_path, safe_serialization=True)
+    transformer = convert_transformer(args)
+    transformer = transformer.to(dtype=dtype)
+    if not args.save_pipeline:
+        transformer.save_pretrained(args.output_path, safe_serialization=True)
+    else:

-    if args.output_vae_path is not None:
         vae = convert_vae(args)
         vae = vae.to(dtype=dtype)
-        vae.save_pretrained(args.output_vae_path, safe_serialization=True)


+        text_encoder, tokenizer = load_mllm()
+        text_encoder_2, tokenizer_2 = load_byt5(args)
+        text_encoder = text_encoder.to(dtype=dtype)
+        text_encoder_2 = text_encoder_2.to(dtype=dtype)
+
+        flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"]
+        scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
+
+        guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"]
+        guider = ClassifierFreeGuidance(guidance_scale=guidance_scale)
+
+        pipeline = HunyuanVideo15Pipeline(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            transformer=transformer,
+            guider=guider,
+            scheduler=scheduler,
+        )
+        pipeline.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
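Not part of the commit: once a full pipeline has been saved with --save_pipeline, it should load back through the usual diffusers from_pretrained interface. A minimal sketch, assuming the output path and bf16 dtype from the second example command at the top and an available CUDA device:

import torch

from diffusers import HunyuanVideo15Pipeline

pipe = HunyuanVideo15Pipeline.from_pretrained(
    "/fsx/yiyi/HunyuanVideo-1.5-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")  # optional device placement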