mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix loading a sharded checkpoint from a subfolder of a local path (#8913)
fix
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
|
||||
_check_if_shards_exist_locally(
|
||||
pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
|
||||
)
|
||||
return pretrained_model_name_or_path, sharded_metadata
|
||||
return shards_path, sharded_metadata
|
||||
|
||||
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
|
||||
allow_patterns = original_shard_filenames
|
||||
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
|
||||
"required according to the checkpoint index."
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if subfolder is not None:
|
||||
cached_folder = os.path.join(cached_folder, subfolder)
|
||||
try:
|
||||
# Load from URL
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if subfolder is not None:
|
||||
cached_folder = os.path.join(cached_folder, subfolder)
|
||||
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
|
||||
except HTTPError as e:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
|
||||
" again after checking your internet connection."
|
||||
) from e
|
||||
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
|
||||
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
|
||||
except HTTPError as e:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
|
||||
" again after checking your internet connection."
|
||||
) from e
|
||||
|
||||
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
|
||||
if local_files_only:
|
||||
elif local_files_only:
|
||||
_check_if_shards_exist_locally(
|
||||
local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
|
||||
)
|
||||
if subfolder is not None:
|
||||
cached_folder = os.path.join(cached_folder, subfolder)
|
||||
|
||||
return cached_folder, sharded_metadata
|
||||
|
||||
|
||||
@@ -1068,6 +1068,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
    # Sharded checkpoint stored under a subfolder, loaded from a local path:
    # fetch the snapshot first, then load from the returned directory with
    # `local_files_only=True`.
    _, model_inputs = self.prepare_init_args_and_inputs_for_common()
    snapshot_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")

    loaded_model = self.model_class.from_pretrained(
        snapshot_dir,
        subfolder="unet",
        local_files_only=True,
    ).to(torch_device)
    new_output = loaded_model(**model_inputs)

    assert loaded_model
    assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
@@ -1077,6 +1088,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self):
    # Sharded checkpoint stored under a subfolder, loaded by repo id with
    # `device_map="auto"`.
    _, model_inputs = self.prepare_init_args_and_inputs_for_common()

    load_kwargs = {"subfolder": "unet", "device_map": "auto"}
    loaded_model = self.model_class.from_pretrained(
        "hf-internal-testing/unet2d-sharded-dummy-subfolder", **load_kwargs
    )
    new_output = loaded_model(**model_inputs)

    assert loaded_model
    assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
@@ -1087,6 +1109,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
    # Sharded checkpoint stored under a subfolder, loaded from a local path
    # with `device_map="auto"` and `local_files_only=True`.
    _, model_inputs = self.prepare_init_args_and_inputs_for_common()
    snapshot_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")

    load_kwargs = {"local_files_only": True, "subfolder": "unet", "device_map": "auto"}
    loaded_model = self.model_class.from_pretrained(snapshot_dir, **load_kwargs)
    new_output = loaded_model(**model_inputs)

    assert loaded_model
    assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_peft_backend
|
||||
def test_lora(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user