mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[core] respect local_files_only=True when using sharded checkpoints (#12005)

* tighten compilation tests for quantization

* feat: model_info but local.

* up

* Revert "tighten compilation tests for quantization"

This reverts commit 8d431dc967.

* up

* reviewer feedback.

* reviewer feedback.

* up

* up

* empty

* update

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>
Sayak Paul
2025-08-14 14:50:51 +05:30
committed by GitHub
parent 46a0c6aa82
commit 1b48db4c8f
2 changed files with 66 additions and 11 deletions
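For context, a minimal sketch (not part of the diff) of the user-facing call this change is about: loading an already-cached sharded checkpoint fully offline. The repo id mirrors the one used in the new test below; any sharded checkpoint works, assuming it was downloaded once beforehand.

from diffusers import FluxTransformer2DModel

# With this fix, local_files_only=True skips the `model_info` network call,
# so a sharded checkpoint that is already in the local cache loads offline.
model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-sharded",
    subfolder="transformer",
    local_files_only=True,
)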


@@ -402,15 +402,17 @@ def _get_checkpoint_shard_files(
         allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
 
     ignore_patterns = ["*.json", "*.md"]
-    # `model_info` call must guarded with the above condition.
-    model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
-    for shard_file in original_shard_filenames:
-        shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
-        if not shard_file_present:
-            raise EnvironmentError(
-                f"{shards_path} does not appear to have a file named {shard_file} which is "
-                "required according to the checkpoint index."
-            )
+
+    # If the repo doesn't have the required shards, error out early even before downloading anything.
+    if not local_files_only:
+        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+        for shard_file in original_shard_filenames:
+            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+            if not shard_file_present:
+                raise EnvironmentError(
+                    f"{shards_path} does not appear to have a file named {shard_file} which is "
+                    "required according to the checkpoint index."
+                )
 
     try:
         # Load from URL
@@ -437,6 +439,11 @@ def _get_checkpoint_shard_files(
         ) from e
 
     cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
+    for cached_file in cached_filenames:
+        if not os.path.isfile(cached_file):
+            raise EnvironmentError(
+                f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
+            )
 
     return cached_filenames, sharded_metadata
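As a standalone illustration of the guarded branch above (not part of the diff), the pre-download check amounts to listing the files the Hub repo actually exposes and verifying that every shard named in the checkpoint index is among them. The repo id and shard name below are examples only.

from huggingface_hub import model_info

# Only reachable when local_files_only=False: ask the Hub which files the repo has.
info = model_info("hf-internal-testing/tiny-flux-sharded")
repo_files = [s.rfilename for s in info.siblings]

# A shard listed in the checkpoint index must appear among the repo files,
# otherwise we error out before downloading anything.
shard_file = "diffusion_pytorch_model-00001-of-00002.safetensors"
if not any(shard_file in f for f in repo_files):
    raise EnvironmentError(f"Repo does not contain required shard {shard_file}")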


@@ -36,12 +36,12 @@ import safetensors.torch
 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
-from huggingface_hub import ModelCard, delete_repo, snapshot_download
+from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
 from requests.exceptions import HTTPError
 
-from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
+from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
@@ -291,6 +291,54 @@ class ModelUtilsTest(unittest.TestCase):
             if p1.data.ne(p2.data).sum() > 0:
                 assert False, "Parameters not the same!"
 
+    def test_local_files_only_with_sharded_checkpoint(self):
+        repo_id = "hf-internal-testing/tiny-flux-sharded"
+        error_response = mock.Mock(
+            status_code=500,
+            headers={},
+            raise_for_status=mock.Mock(side_effect=HTTPError),
+            json=mock.Mock(return_value={}),
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
+
+            with mock.patch("requests.Session.get", return_value=error_response):
+                # Should fail with local_files_only=False (network required)
+                # We would make a network call with model_info
+                with self.assertRaises(OSError):
+                    FluxTransformer2DModel.from_pretrained(
+                        repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
+                    )
+
+                # Should succeed with local_files_only=True (uses cache)
+                # model_info call skipped
+                local_model = FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+                assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
+                    "Model parameters don't match!"
+                )
+
+            # Remove a shard file
+            cached_shard_file = try_to_load_from_cache(
+                repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
+            )
+            os.remove(cached_shard_file)
+
+            # Attempting to load from cache should raise an error
+            with self.assertRaises(OSError) as context:
+                FluxTransformer2DModel.from_pretrained(
+                    repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+                )
+
+            # Verify error mentions the missing shard
+            error_msg = str(context.exception)
+            assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
+                f"Expected error about missing shard, got: {error_msg}"
+            )
+
     @unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
     @unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
     def test_one_request_upon_cached(self):
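For reference, a small sketch of how `try_to_load_from_cache` (used by the test above to locate the shard it deletes) behaves; the repo id and filename follow the test and assume the checkpoint was downloaded to the default cache earlier.

from huggingface_hub import try_to_load_from_cache

path = try_to_load_from_cache(
    "hf-internal-testing/tiny-flux-sharded",
    filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors",
)
# A plain string means the file is present in the local cache; anything else
# (None or a "known missing" sentinel) means there is nothing to load offline.
if isinstance(path, str):
    print(f"cached shard: {path}")
else:
    print("shard not available in the local cache")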