mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Conversion] Support convert diffusers to safetensors (#1996)
fix: support diffusers to safetensors
This commit is contained in:
@@ -8,6 +8,8 @@ import re
|
||||
|
||||
import torch
|
||||
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
# =================#
|
||||
# UNet Conversion #
|
||||
@@ -266,6 +268,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
||||
parser.add_argument(
|
||||
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -306,5 +311,9 @@ if __name__ == "__main__":
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
||||
if args.half:
|
||||
state_dict = {k: v.half() for k, v in state_dict.items()}
|
||||
state_dict = {"state_dict": state_dict}
|
||||
torch.save(state_dict, args.checkpoint_path)
|
||||
|
||||
if args.use_safetensors:
|
||||
save_file(state_dict, args.checkpoint_path)
|
||||
else:
|
||||
state_dict = {"state_dict": state_dict}
|
||||
torch.save(state_dict, args.checkpoint_path)
|
||||
|
||||
Reference in New Issue
Block a user