mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Safetensor loading in AnimateDiff conversion scripts (#7764)
* update * update
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
def convert_motion_module(original_state_dict):
|
||||
@@ -34,7 +34,10 @@ def get_args():
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if args.ckpt_path.endswith(".safetensors"):
|
||||
state_dict = load_file(args.ckpt_path)
|
||||
else:
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from diffusers import MotionAdapter
|
||||
|
||||
@@ -38,7 +39,11 @@ def get_args():
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if args.ckpt_path.endswith(".safetensors"):
|
||||
state_dict = load_file(args.ckpt_path)
|
||||
else:
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in state_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user