diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index bc43ba83cf..6b402d9a4f 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -405,20 +405,21 @@ def _get_checkpoint_shard_files( ignore_patterns = ["*.json", "*.md"] # `model_info` call must guarded with the above condition. local = False - try: - model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) - except ConnectionError as e: - if local_files_only: - temp_dir = snapshot_download( - repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only - ) - model_files_info = _get_filepaths_for_folder(temp_dir) - local = True - else: + if local_files_only: + temp_dir = snapshot_download( + repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only + ) + model_files_info = _get_filepaths_for_folder(temp_dir) + local = True + else: + try: + model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) + except ConnectionError as e: raise EnvironmentError( f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" " again after checking your internet connection." ) from e + for shard_file in original_shard_filenames: if local: shard_file_present = any(shard_file in k for k in model_files_info)