mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix setting fp16 dtype in AnimateDiff convert script. (#7127)
* update * update
This commit is contained in:
@@ -30,6 +30,7 @@ def get_args():
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--use_motion_mid_block", action="store_true")
|
||||
parser.add_argument("--motion_max_seq_length", type=int, default=32)
|
||||
parser.add_argument("--save_fp16", action="store_true")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -48,4 +49,6 @@ if __name__ == "__main__":
|
||||
# skip loading position embeddings
|
||||
adapter.load_state_dict(conv_state_dict, strict=False)
|
||||
adapter.save_pretrained(args.output_path)
|
||||
adapter.save_pretrained(args.output_path, variant="fp16", torch_dtype=torch.float16)
|
||||
|
||||
if args.save_fp16:
|
||||
adapter.to(torch.float16).save_pretrained(args.output_path, variant="fp16")
|
||||
|
||||
Reference in New Issue
Block a user