diff --git a/scripts/convert_animatediff_motion_lora_to_diffusers.py b/scripts/convert_animatediff_motion_lora_to_diffusers.py index 509a734579..c680fdc684 100644 --- a/scripts/convert_animatediff_motion_lora_to_diffusers.py +++ b/scripts/convert_animatediff_motion_lora_to_diffusers.py @@ -1,7 +1,7 @@ import argparse import torch -from safetensors.torch import save_file +from safetensors.torch import load_file, save_file def convert_motion_module(original_state_dict): @@ -34,7 +34,10 @@ def get_args(): if __name__ == "__main__": args = get_args() - state_dict = torch.load(args.ckpt_path, map_location="cpu") + if args.ckpt_path.endswith(".safetensors"): + state_dict = load_file(args.ckpt_path) + else: + state_dict = torch.load(args.ckpt_path, map_location="cpu") if "state_dict" in state_dict.keys(): state_dict = state_dict["state_dict"] diff --git a/scripts/convert_animatediff_motion_module_to_diffusers.py b/scripts/convert_animatediff_motion_module_to_diffusers.py index ceb967acd3..e8fb007243 100644 --- a/scripts/convert_animatediff_motion_module_to_diffusers.py +++ b/scripts/convert_animatediff_motion_module_to_diffusers.py @@ -1,6 +1,7 @@ import argparse import torch +from safetensors.torch import load_file from diffusers import MotionAdapter @@ -38,7 +39,11 @@ def get_args(): if __name__ == "__main__": args = get_args() - state_dict = torch.load(args.ckpt_path, map_location="cpu") + if args.ckpt_path.endswith(".safetensors"): + state_dict = load_file(args.ckpt_path) + else: + state_dict = torch.load(args.ckpt_path, map_location="cpu") + if "state_dict" in state_dict.keys(): state_dict = state_dict["state_dict"]