1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

feat: model_info but local.

This commit is contained in:
sayakpaul
2025-07-28 15:16:53 +05:30
parent 8d431dc967
commit 69920eff3e

View File

@@ -404,9 +404,21 @@ def _get_checkpoint_shard_files(
ignore_patterns = ["*.json", "*.md"]
# `model_info` call must guarded with the above condition.
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
local = False
try:
model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
except HTTPError:
if local_files_only:
temp_dir = snapshot_download(
repo_id=pretrained_model_name_or_path, cache_dir=cache_dir, local_files_only=local_files_only
)
model_files_info = _get_filepaths_for_folder(temp_dir)
local = True
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if local:
shard_file_present = any(shard_file in k for k in model_files_info)
else:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:
raise EnvironmentError(
f"{shards_path} does not appear to have a file named {shard_file} which is "
@@ -442,6 +454,16 @@ def _get_checkpoint_shard_files(
return cached_filenames, sharded_metadata
def _get_filepaths_for_folder(folder):
relative_paths = []
for root, dirs, files in os.walk(folder):
for fname in files:
abs_path = os.path.join(root, fname)
rel_path = os.path.relpath(abs_path, start=folder)
relative_paths.append(rel_path)
return relative_paths
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
if filenames and folder:
raise ValueError("Both `filenames` and `folder` cannot be provided.")