mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add pipeline_class_name argument to Stable Diffusion conversion script (#4461)
* add pipeline class * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * style --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
""" Conversion script for the LDM checkpoints. """
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
|
||||
@@ -133,8 +134,22 @@ if __name__ == "__main__":
|
||||
required=False,
|
||||
help="Set to a path, hub id to an already converted vae to not convert it again.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pipeline_class_name",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Specify the pipeline class name",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.pipeline_class_name is not None:
|
||||
library = importlib.import_module("diffusers")
|
||||
class_obj = getattr(library, args.pipeline_class_name)
|
||||
else:
|
||||
pipeline_class = None
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=args.checkpoint_path,
|
||||
original_config_file=args.original_config_file,
|
||||
@@ -152,6 +167,7 @@ if __name__ == "__main__":
|
||||
clip_stats_path=args.clip_stats_path,
|
||||
controlnet=args.controlnet,
|
||||
vae_path=args.vae_path,
|
||||
pipeline_class=pipeline_class,
|
||||
)
|
||||
|
||||
if args.half:
|
||||
|
||||
Reference in New Issue
Block a user