From 04cd2dc451e56381b0e11131d4923dc7ef29ecd7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 13 Aug 2025 14:50:50 +0530 Subject: [PATCH] reviewer feedback. --- src/diffusers/utils/hub_utils.py | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index de29ec4122..0ddfcaee34 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -44,7 +44,6 @@ from huggingface_hub.utils import ( ) from packaging import version from requests import HTTPError -from requests.exceptions import ConnectionError from .. import __version__ from .constants import ( @@ -402,24 +401,6 @@ def _get_checkpoint_shard_files( allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns] ignore_patterns = ["*.json", "*.md"] - try: - temp_dir = snapshot_download( - repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only - ) - except ConnectionError as e: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" - " again after checking your internet connection." - ) from e - - model_files_info = _get_filepaths_for_folder(temp_dir) - for shard_file in original_shard_filenames: - shard_file_present = any(shard_file in k for k in model_files_info) - if not shard_file_present: - raise EnvironmentError( - f"{shards_path} does not appear to have a file named {shard_file} which is " - "required according to the checkpoint index." - ) try: # Load from URL @@ -437,6 +418,15 @@ def _get_checkpoint_shard_files( if subfolder is not None: cached_folder = os.path.join(cached_folder, subfolder) + model_files_info = _get_filepaths_for_folder(cached_folder) + for shard_file in original_shard_filenames: + shard_file_present = any(shard_file in k for k in model_files_info) + if not shard_file_present: + raise EnvironmentError( + f"{shards_path} does not appear to have a file named {shard_file} which is " + "required according to the checkpoint index." + ) + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so # we don't have to catch them here. We have also dealt with EntryNotFoundError. except HTTPError as e: