diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index 11b4b873e7..ffcea8332f 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -8,6 +8,8 @@ import re import torch +from safetensors.torch import save_file + # =================# # UNet Conversion # @@ -266,6 +268,9 @@ if __name__ == "__main__": parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument("--half", action="store_true", help="Save weights in half precision.") + parser.add_argument( + "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt." + ) args = parser.parse_args() @@ -306,5 +311,9 @@ if __name__ == "__main__": state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} if args.half: state_dict = {k: v.half() for k, v in state_dict.items()} - state_dict = {"state_dict": state_dict} - torch.save(state_dict, args.checkpoint_path) + + if args.use_safetensors: + save_file(state_dict, args.checkpoint_path) + else: + state_dict = {"state_dict": state_dict} + torch.save(state_dict, args.checkpoint_path)