diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 97fde62cdb..fcdf49156a 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -405,20 +405,14 @@ def _get_checkpoint_shard_files(
 
     # If the repo doesn't have the required shards, error out early even before downloading anything.
     if not local_files_only:
-        try:
-            model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
-            for shard_file in original_shard_filenames:
-                shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
-                if not shard_file_present:
-                    raise EnvironmentError(
-                        f"{shards_path} does not appear to have a file named {shard_file} which is "
-                        "required according to the checkpoint index."
-                    )
-        except ConnectionError as e:
-            raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
-                " again after checking your internet connection."
-            ) from e
+        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+        for shard_file in original_shard_filenames:
+            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+            if not shard_file_present:
+                raise EnvironmentError(
+                    f"{shards_path} does not appear to have a file named {shard_file} which is "
+                    "required according to the checkpoint index."
+                )
 
     try:
         # Load from URL
@@ -436,16 +430,6 @@
         if subfolder is not None:
             cached_folder = os.path.join(cached_folder, subfolder)
 
-        # Check again after downloading/loading from the cache.
-        model_files_info = _get_filepaths_for_folder(cached_folder)
-        for shard_file in original_shard_filenames:
-            shard_file_present = any(shard_file in k for k in model_files_info)
-            if not shard_file_present:
-                raise EnvironmentError(
-                    f"{shards_path} does not appear to have a file named {shard_file} which is "
-                    "required according to the checkpoint index."
-                )
-
     # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
     # we don't have to catch them here. We have also dealt with EntryNotFoundError.
     except HTTPError as e:
@@ -455,20 +439,15 @@
         ) from e
 
     cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
+    for cached_file in cached_filenames:
+        if not os.path.isfile(cached_file):
+            raise EnvironmentError(
+                f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
+            )
 
     return cached_filenames, sharded_metadata
 
 
-def _get_filepaths_for_folder(folder):
-    relative_paths = []
-    for root, dirs, files in os.walk(folder):
-        for fname in files:
-            abs_path = os.path.join(root, fname)
-            rel_path = os.path.relpath(abs_path, start=folder)
-            relative_paths.append(rel_path)
-    return relative_paths
-
-
 def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
     if filenames and folder:
         raise ValueError("Both `filenames` and `folder` cannot be provided.")
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 0e16f95a42..1e08191f56 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -36,12 +36,12 @@ import safetensors.torch
 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
-from huggingface_hub import ModelCard, delete_repo, snapshot_download
+from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
 from requests.exceptions import HTTPError
 
-from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
+from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
@@ -291,6 +291,54 @@ class ModelUtilsTest(unittest.TestCase):
             if p1.data.ne(p2.data).sum() > 0:
                 assert False, "Parameters not the same!"
 
+    def test_local_files_only_with_sharded_checkpoint(self):
+        repo_id = "hf-internal-testing/tiny-flux-sharded"
+        error_response = mock.Mock(
+            status_code=500,
+            headers={},
+            raise_for_status=mock.Mock(side_effect=HTTPError),
+            json=mock.Mock(return_value={}),
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
+
+            with mock.patch("requests.Session.get", return_value=error_response):
+                # Should fail with local_files_only=False (network required)
+                # We would make a network call with model_info
+                with self.assertRaises(OSError):
+                    FluxTransformer2DModel.from_pretrained(
+                        repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
+                    )
+
+                # Should succeed with local_files_only=True (uses cache)
+                # model_info call skipped
+                local_model = FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+            assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
+                "Model parameters don't match!"
+            )
+
+            # Remove a shard file
+            cached_shard_file = try_to_load_from_cache(
+                repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
+            )
+            os.remove(cached_shard_file)
+
+            # Attempting to load from cache should raise an error
+            with self.assertRaises(OSError) as context:
+                FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+            # Verify error mentions the missing shard
+            error_msg = str(context.exception)
+            assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
+                f"Expected error about missing shard, got: {error_msg}"
+            )
+
     @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
     @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
     def test_one_request_upon_cached(self):
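
For context, here is a minimal standalone sketch (not part of the diff) of the behavior the new test locks in: once a sharded checkpoint has been downloaded, a `local_files_only=True` load resolves every shard from the cache without calling `model_info`, and the new post-resolution `os.path.isfile` check turns a deleted shard into an immediate `EnvironmentError`. The repo id and shard filename are the ones used in the test above; it assumes network access for the first download and that the shard is actually present in the cache afterwards.

```python
import os
import tempfile

import torch
from huggingface_hub import try_to_load_from_cache

from diffusers import FluxTransformer2DModel

repo_id = "hf-internal-testing/tiny-flux-sharded"  # sharded test checkpoint from the test above

with tempfile.TemporaryDirectory() as tmpdir:
    # First load populates the cache (network required).
    model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)

    # Second load resolves all shards from the cache; no model_info call is made.
    local_model = FluxTransformer2DModel.from_pretrained(
        repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
    )
    assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters()))

    # Deleting a cached shard now fails fast via the post-resolution check.
    # We assume the shard was cached by the first load, so this returns a path.
    shard = try_to_load_from_cache(
        repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
    )
    os.remove(shard)
    try:
        FluxTransformer2DModel.from_pretrained(
            repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
        )
    except EnvironmentError as e:
        print(f"Missing shard detected as expected: {e}")
```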