1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Conversion] Support convert diffusers to safetensors (#1996)

fix: support diffusers to safetensors
This commit is contained in:
蓝色的秋风
2023-01-16 19:58:01 +08:00
committed by GitHub
parent cc2cc00d20
commit 651c5adf8a

View File

@@ -8,6 +8,8 @@ import re
import torch
from safetensors.torch import save_file
# =================#
# UNet Conversion #
@@ -266,6 +268,9 @@ if __name__ == "__main__":
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
parser.add_argument(
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
)
args = parser.parse_args()
@@ -306,5 +311,9 @@ if __name__ == "__main__":
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
if args.half:
state_dict = {k: v.half() for k, v in state_dict.items()}
state_dict = {"state_dict": state_dict}
torch.save(state_dict, args.checkpoint_path)
if args.use_safetensors:
save_file(state_dict, args.checkpoint_path)
else:
state_dict = {"state_dict": state_dict}
torch.save(state_dict, args.checkpoint_path)