From 796c01534dcc8856aa81d69753a2df758274d625 Mon Sep 17 00:00:00 2001 From: Mystfit Date: Fri, 11 Aug 2023 17:41:43 +1200 Subject: [PATCH] Fixing repo_id regex validation error on windows platforms (#4358) * Fixing repo_id regex validation error on Windows platforms * Validating that a URL with a valid prefix is provided If we are loading a URL then we don't need to use os.path.join and array slicing to split out a repo_id and file path from an absolute filepath. We check that the URL prefix is valid before doing any URL splitting; otherwise we raise a ValueError, since neither a valid filepath nor a valid URL was provided. * Style fixes --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a2fe96e51c..0907cbfd16 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1877,16 +1877,24 @@ class FromSingleFileMixin: raise ValueError(f"Unhandled pipeline class: {pipeline_name}") # remove huggingface url - for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: + has_valid_url_prefix = False + valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] + for prefix in valid_url_prefixes: if pretrained_model_link_or_path.startswith(prefix): pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] + has_valid_url_prefix = True # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained ckpt_path = Path(pretrained_model_link_or_path) if not ckpt_path.is_file(): + if not has_valid_url_prefix: + raise ValueError( + f"The provided path is either not a file or a valid huggingface URL was not provided. 
Valid URLs begin with {', '.join(valid_url_prefixes)}" + ) + # get repo_id and (potentially nested) file path of ckpt in repo - repo_id = os.path.join(*ckpt_path.parts[:2]) - file_path = os.path.join(*ckpt_path.parts[2:]) + repo_id = "/".join(ckpt_path.parts[:2]) + file_path = "/".join(ckpt_path.parts[2:]) if file_path.startswith("blob/"): file_path = file_path[len("blob/") :]