From 1c528a4166bcb6e10ee0152f249638b33870ee3f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 13 Aug 2025 14:55:18 +0530 Subject: [PATCH] up --- src/diffusers/utils/hub_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 0ddfcaee34..423584fd92 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -30,6 +30,7 @@ from huggingface_hub import ( ModelCardData, create_repo, hf_hub_download, + model_info, snapshot_download, upload_folder, ) @@ -402,6 +403,23 @@ def _get_checkpoint_shard_files( ignore_patterns = ["*.json", "*.md"] + # If the repo doesn't have the required shards, error out early even before downloading anything. + if not local_files_only: + try: + model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) + if not shard_file_present: + raise EnvironmentError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + except ConnectionError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e + try: # Load from URL cached_folder = snapshot_download(