diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 4ed845057f..82d29ce73f 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -75,14 +75,15 @@ def extract_pipeline_component_names(pipeline_class): def check_valid_url(pretrained_model_link_or_path): - # remove huggingface url + # check if url prefix is valid + # remove huggingface url prefix from model path has_valid_url_prefix = False for prefix in VALID_URL_PREFIXES: if pretrained_model_link_or_path.startswith(prefix): pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] has_valid_url_prefix = True - return has_valid_url_prefix + return has_valid_url_prefix, pretrained_model_link_or_path def download_model_checkpoint( @@ -306,7 +307,7 @@ class FromSingleFileMixin: if from_safetensors and use_safetensors is False: raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - has_valid_url_prefix = check_valid_url(pretrained_model_link_or_path) + has_valid_url_prefix, pretrained_model_link_or_path = check_valid_url(pretrained_model_link_or_path) # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained ckpt_path = Path(pretrained_model_link_or_path) @@ -314,7 +315,9 @@ class FromSingleFileMixin: raise ValueError( f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}" ) - if not ckpt_path.is_file(): + if ckpt_path.is_file(): + checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors) + else: pretrained_model_link_or_path = download_model_checkpoint( ckpt_path, cache_dir=cache_dir, @@ -325,8 +328,6 @@ class FromSingleFileMixin: revision=revision, ) checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors) - else: - checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors) # NOTE: this while loop isn't great but this controlnet checkpoint has one additional # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21