mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix loading sharded checkpoints from subfolder (#8798)
* fix load sharded checkpoints from subfolder{
* style
* os.path.join
* add a small test
---------
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -221,7 +221,7 @@ def _fetch_index_file(
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
subfolder=None,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
@@ -455,10 +455,13 @@ def _get_checkpoint_shard_files(
|
||||
|
||||
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
||||
allow_patterns = original_shard_filenames
|
||||
if subfolder is not None:
|
||||
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
|
||||
|
||||
ignore_patterns = ["*.json", "*.md"]
|
||||
if not local_files_only:
|
||||
# `model_info` call must guarded with the above condition.
|
||||
model_files_info = model_info(pretrained_model_name_or_path)
|
||||
model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
|
||||
for shard_file in original_shard_filenames:
|
||||
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
|
||||
if not shard_file_present:
|
||||
@@ -481,6 +484,8 @@ def _get_checkpoint_shard_files(
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if subfolder is not None:
|
||||
cached_folder = os.path.join(cached_folder, subfolder)
|
||||
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
|
||||
|
||||
@@ -1045,6 +1045,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet"
|
||||
)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user