1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

reviewer feedback.

This commit is contained in:
sayakpaul
2025-08-13 14:50:50 +05:30
parent b7af5111c4
commit 04cd2dc451

View File

@@ -44,7 +44,6 @@ from huggingface_hub.utils import (
)
from packaging import version
from requests import HTTPError
from requests.exceptions import ConnectionError
from .. import __version__
from .constants import (
@@ -402,24 +401,6 @@ def _get_checkpoint_shard_files(
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
ignore_patterns = ["*.json", "*.md"]
try:
temp_dir = snapshot_download(
repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
)
except ConnectionError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
model_files_info = _get_filepaths_for_folder(temp_dir)
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k for k in model_files_info)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
try:
# Load from URL
@@ -437,6 +418,15 @@ def _get_checkpoint_shard_files(
if subfolder is not None:
cached_folder = os.path.join(cached_folder, subfolder)
model_files_info = _get_filepaths_for_folder(cached_folder)
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k for k in model_files_info)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
"required according to the checkpoint index."
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except HTTPError as e: