1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2023-12-30 11:11:25 +00:00
parent 94536262cb
commit afa62e6fa8

View File

@@ -75,14 +75,15 @@ def extract_pipeline_component_names(pipeline_class):
def check_valid_url(pretrained_model_link_or_path):
# remove huggingface url
# check if url prefix is valid
# remove huggingface url prefix from model path
has_valid_url_prefix = False
for prefix in VALID_URL_PREFIXES:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
has_valid_url_prefix = True
return has_valid_url_prefix
return has_valid_url_prefix, pretrained_model_link_or_path
def download_model_checkpoint(
@@ -306,7 +307,7 @@ class FromSingleFileMixin:
if from_safetensors and use_safetensors is False:
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
has_valid_url_prefix = check_valid_url(pretrained_model_link_or_path)
has_valid_url_prefix, pretrained_model_link_or_path = check_valid_url(pretrained_model_link_or_path)
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
@@ -314,7 +315,9 @@ class FromSingleFileMixin:
raise ValueError(
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}"
)
if not ckpt_path.is_file():
if ckpt_path.is_file():
checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
else:
pretrained_model_link_or_path = download_model_checkpoint(
ckpt_path,
cache_dir=cache_dir,
@@ -325,8 +328,6 @@ class FromSingleFileMixin:
revision=revision,
)
checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
else:
checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21