1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
sayakpaul
2025-08-12 20:26:54 +05:30
parent 71843a0c8b
commit fb2397f3fe

View File

@@ -405,20 +405,21 @@ def _get_checkpoint_shard_files(
ignore_patterns = ["*.json", "*.md"]
# `model_info` call must guarded with the above condition.
local = False
try:
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
except ConnectionError as e:
if local_files_only:
temp_dir = snapshot_download(
repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
)
model_files_info = _get_filepaths_for_folder(temp_dir)
local = True
else:
if local_files_only:
temp_dir = snapshot_download(
repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
)
model_files_info = _get_filepaths_for_folder(temp_dir)
local = True
else:
try:
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
except ConnectionError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
for shard_file in original_shard_filenames:
if local:
shard_file_present = any(shard_file in k for k in model_files_info)