mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
allow loading of sd models from safetensors without online lookups using local config files (#5019)
finish config_files implementation
This commit is contained in:
@@ -154,6 +154,7 @@ if __name__ == "__main__":
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path_or_dict=args.checkpoint_path,
|
||||
original_config_file=args.original_config_file,
|
||||
config_files=args.config_files,
|
||||
image_size=args.image_size,
|
||||
prediction_type=args.prediction_type,
|
||||
model_type=args.pipeline_type,
|
||||
|
||||
@@ -2099,6 +2099,7 @@ class FromSingleFileMixin:
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_files = kwargs.pop("config_files", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -2216,6 +2217,7 @@ class FromSingleFileMixin:
|
||||
vae=vae,
|
||||
tokenizer=tokenizer,
|
||||
original_config_file=original_config_file,
|
||||
config_files=config_files,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
|
||||
@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||
config_url = None
|
||||
|
||||
# model_type = "v1"
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
if config_files is not None and "v1" in config_files:
|
||||
original_config_file = config_files["v1"]
|
||||
else:
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
|
||||
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
|
||||
# model_type = "v2"
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
|
||||
if config_files is not None and "v2" in config_files:
|
||||
original_config_file = config_files["v2"]
|
||||
else:
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
elif key_name_sd_xl_base in checkpoint:
|
||||
# only base xl has two text embedders
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
|
||||
if config_files is not None and "xl" in config_files:
|
||||
original_config_file = config_files["xl"]
|
||||
else:
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
|
||||
elif key_name_sd_xl_refiner in checkpoint:
|
||||
# only refiner xl has embedder and one text embedders
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
|
||||
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
if config_files is not None and "xl_refiner" in config_files:
|
||||
original_config_file = config_files["xl_refiner"]
|
||||
else:
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
|
||||
if config_url is not None:
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user