diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
index 3189c83602..bc4f06164c 100644
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -154,6 +154,7 @@ if __name__ == "__main__":
     pipe = download_from_original_stable_diffusion_ckpt(
         checkpoint_path_or_dict=args.checkpoint_path,
         original_config_file=args.original_config_file,
+        config_files=args.config_files,
         image_size=args.image_size,
         prediction_type=args.prediction_type,
         model_type=args.pipeline_type,
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 5fbf6b819f..45c866c1aa 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -2099,6 +2099,7 @@ class FromSingleFileMixin:
         from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

         original_config_file = kwargs.pop("original_config_file", None)
+        config_files = kwargs.pop("config_files", None)
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
@@ -2216,6 +2217,7 @@ class FromSingleFileMixin:
             vae=vae,
             tokenizer=tokenizer,
             original_config_file=original_config_file,
+            config_files=config_files,
         )

         if torch_dtype is not None:
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 22b424bf16..618ee19422 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
         key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
         key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+        config_url = None

         # model_type = "v1"
-        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+        if config_files is not None and "v1" in config_files:
+            original_config_file = config_files["v1"]
+        else:
+            config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

         if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
             # model_type = "v2"
-            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
-
+            if config_files is not None and "v2" in config_files:
+                original_config_file = config_files["v2"]
+            else:
+                config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
             if global_step == 110000:
                 # v2.1 needs to upcast attention
                 upcast_attention = True
         elif key_name_sd_xl_base in checkpoint:
             # only base xl has two text embedders
-            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+            if config_files is not None and "xl" in config_files:
+                original_config_file = config_files["xl"]
+            else:
+                config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
         elif key_name_sd_xl_refiner in checkpoint:
             # only refiner xl has embedder and one text embedders
-            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
-
-        original_config_file = BytesIO(requests.get(config_url).content)
+            if config_files is not None and "xl_refiner" in config_files:
+                original_config_file = config_files["xl_refiner"]
+            else:
+                config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+
+        if config_url is not None:
+            original_config_file = BytesIO(requests.get(config_url).content)

     original_config = OmegaConf.load(original_config_file)