mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -75,14 +75,15 @@ def extract_pipeline_component_names(pipeline_class):
|
||||
|
||||
|
||||
def check_valid_url(pretrained_model_link_or_path):
|
||||
# remove huggingface url
|
||||
# check if url prefix is valid
|
||||
# remove huggingface url prefix from model path
|
||||
has_valid_url_prefix = False
|
||||
for prefix in VALID_URL_PREFIXES:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
has_valid_url_prefix = True
|
||||
|
||||
return has_valid_url_prefix
|
||||
return has_valid_url_prefix, pretrained_model_link_or_path
|
||||
|
||||
|
||||
def download_model_checkpoint(
|
||||
@@ -306,7 +307,7 @@ class FromSingleFileMixin:
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
has_valid_url_prefix = check_valid_url(pretrained_model_link_or_path)
|
||||
has_valid_url_prefix, pretrained_model_link_or_path = check_valid_url(pretrained_model_link_or_path)
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
@@ -314,7 +315,9 @@ class FromSingleFileMixin:
|
||||
raise ValueError(
|
||||
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}"
|
||||
)
|
||||
if not ckpt_path.is_file():
|
||||
if ckpt_path.is_file():
|
||||
checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
|
||||
else:
|
||||
pretrained_model_link_or_path = download_model_checkpoint(
|
||||
ckpt_path,
|
||||
cache_dir=cache_dir,
|
||||
@@ -325,8 +328,6 @@ class FromSingleFileMixin:
|
||||
revision=revision,
|
||||
)
|
||||
checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
|
||||
else:
|
||||
checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
|
||||
|
||||
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
|
||||
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
|
||||
|
||||
Reference in New Issue
Block a user