From 651c5adf8a5c1ca2d4bac339dc2bca1e1264bd25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=93=9D=E8=89=B2=E7=9A=84=E7=A7=8B=E9=A3=8E?= <461249104@qq.com> Date: Mon, 16 Jan 2023 19:58:01 +0800 Subject: [PATCH] [Conversion] Support convert diffusers to safetensors (#1996) fix: support diffusers to safetensors --- ...onvert_diffusers_to_original_stable_diffusion.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index 11b4b873e7..ffcea8332f 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -8,6 +8,8 @@ import re import torch +from safetensors.torch import save_file + # =================# # UNet Conversion # @@ -266,6 +268,9 @@ if __name__ == "__main__": parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt." + ) args = parser.parse_args() @@ -306,5 +311,9 @@ if __name__ == "__main__": state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} if args.half: state_dict = {k: v.half() for k, v in state_dict.items()} - state_dict = {"state_dict": state_dict} - torch.save(state_dict, args.checkpoint_path) + + if args.use_safetensors: + save_file(state_dict, args.checkpoint_path) + else: + state_dict = {"state_dict": state_dict} + torch.save(state_dict, args.checkpoint_path)