
update conversion script: remove dtype, always keep same precision as original checkpoint

commit a0b2fe02b0
parent 2f6914d57a
Author: yiyi@huggingface.co
Date: 2025-11-27 05:24:49 +00:00
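
With the --dtype flag and the post-hoc .to(dtype=...) casts gone, converted weights stay in whatever precision the original checkpoint was stored in. A minimal sketch of that idea, assuming a safetensors checkpoint; the key rename shown is hypothetical and the real script's mapping is more involved:

from safetensors.torch import load_file, save_file

def convert_keep_dtype(in_path: str, out_path: str) -> None:
    # safetensors preserves each tensor's on-disk dtype when loading.
    state_dict = load_file(in_path)
    converted = {}
    for key, tensor in state_dict.items():
        # Hypothetical rename to the diffusers layout; the point is that no
        # .to(dtype=...) cast is applied, so fp32 stays fp32 and bf16 stays bf16.
        converted[key.replace("final_layer.", "proj_out.")] = tensor
    save_file(converted, out_path)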

--- a/scripts/convert_hunyuan_video1_5_to_diffusers.py
+++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py

@@ -1,19 +1,17 @@
 """
 python scripts/convert_hunyuan_video1_5_to_diffusers.py \
     --original_state_dict_repo_id tencent/HunyuanVideo-1.5\
-    --output_path /fsx/yiyi/hy15/480p_i2v\
-    --transformer_type 480p_i2v \
-    --dtype fp32
+    --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers/transformer\
+    --transformer_type 480p_t2v
 """
 
 """
 python scripts/convert_hunyuan_video1_5_to_diffusers.py \
     --original_state_dict_repo_id tencent/HunyuanVideo-1.5\
     --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \
-    --dtype bf16 \
     --save_pipeline \
     --byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\
-    --transformer_type 480p_i2v
+    --transformer_type 480p_t2v
 """
 
 import argparse
@@ -51,12 +49,33 @@ TRANSFORMER_CONFIGS = {
         "rope_axes_dim": (16, 56, 56),
         "use_meanflow": False,
     },
+    "480p_t2v": {
+        "in_channels": 65,
+        "out_channels": 32,
+        "num_attention_heads": 16,
+        "attention_head_dim": 128,
+        "num_layers": 54,
+        "num_refiner_layers": 2,
+        "mlp_ratio": 4.0,
+        "patch_size": 1,
+        "patch_size_t": 1,
+        "qk_norm": "rms_norm",
+        "text_embed_dim": 3584,
+        "text_embed_2_dim": 1472,
+        "image_embed_dim": 1152,
+        "rope_theta": 256.0,
+        "rope_axes_dim": (16, 56, 56),
+        "use_meanflow": False,
+    },
 }
 
 SCHEDULER_CONFIGS = {
     "480p_i2v": {
         "shift": 5.0,
     },
+    "480p_t2v": {
+        "shift": 5.0,
+    },
 }
 
 GUIDANCE_CONFIGS = {
@@ -64,6 +83,10 @@ GUIDANCE_CONFIGS = {
         "guidance_scale": 6.0,
         "embedded_guidance_scale": None,
     },
+    "480p_t2v": {
+        "guidance_scale": 6.0,
+        "embedded_guidance_scale": None,
+    },
 }
 
@@ -555,6 +578,7 @@ def load_original_transformer_state_dict(args):
         model_dir = model_dir / "transformer" / args.transformer_type
     return load_sharded_safetensors(model_dir)
 
+
 def load_original_vae_state_dict(args):
     if args.original_state_dict_repo_id is not None:
         ckpt_path = hf_hub_download(
@@ -570,6 +594,7 @@ def load_original_vae_state_dict(args):
 
     original_state_dict = load_file(ckpt_path)
     return original_state_dict
 
+
 def convert_transformer(args):
     original_state_dict = load_original_transformer_state_dict(args)
@@ -581,6 +606,7 @@ def convert_transformer(args):
 
     return transformer
 
+
 def convert_vae(args):
     original_state_dict = load_original_vae_state_dict(args)
     with init_empty_weights():
@@ -591,7 +617,7 @@
 
 def load_mllm():
     print(f" loading from Qwen/Qwen2.5-VL-7B-Instruct")
-    text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", low_cpu_mem_usage=True)
+    text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16,low_cpu_mem_usage=True)
     if hasattr(text_encoder, 'language_model'):
         text_encoder = text_encoder.language_model
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
@@ -641,8 +667,6 @@ def add_special_token(
     text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
 
 
-
-
 def load_byt5(args):
     """
     Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format.
@@ -714,7 +738,6 @@ def get_args():
     )
     parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
-    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
     parser.add_argument(
         "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())
     )
@@ -736,13 +759,6 @@ def get_args():
     return parser.parse_args()
 
 
-DTYPE_MAPPING = {
-    "fp32": torch.float32,
-    "fp16": torch.float16,
-    "bf16": torch.bfloat16,
-}
-
-
 if __name__ == "__main__":
     args = get_args()
 
@@ -750,22 +766,16 @@ if __name__ == "__main__":
         raise ValueError("Please provide --byt5_path when saving pipeline")
 
     transformer = None
-    dtype = DTYPE_MAPPING[args.dtype]
 
     transformer = convert_transformer(args)
-    transformer = transformer.to(dtype=dtype)
 
     if not args.save_pipeline:
         transformer.save_pretrained(args.output_path, safe_serialization=True)
     else:
         vae = convert_vae(args)
-        vae = vae.to(dtype=dtype)
         text_encoder, tokenizer = load_mllm()
         text_encoder_2, tokenizer_2 = load_byt5(args)
 
-        text_encoder = text_encoder.to(dtype=dtype)
-        text_encoder_2 = text_encoder_2.to(dtype=dtype)
-
         flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"]
         scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
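
One way to confirm the conversion carried precision over unchanged is to inspect the dtypes in the saved shards. A small sketch; the shard path below is a placeholder for whatever --output_path produced:

from safetensors.torch import load_file

# Placeholder shard path; a sharded transformer has several such files,
# each inspectable the same way.
state_dict = load_file("transformer/diffusion_pytorch_model.safetensors")
print({tensor.dtype for tensor in state_dict.values()})
# Expected: the original checkpoint's dtype(s), e.g. {torch.bfloat16},
# rather than a dtype forced by a conversion flag.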