diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 020411dc78..1cdc02e873 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -448,7 +448,7 @@ def _get_checkpoint_shard_files( _check_if_shards_exist_locally( pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames ) - return pretrained_model_name_or_path, sharded_metadata + return shards_path, sharded_metadata # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames @@ -467,35 +467,37 @@ def _get_checkpoint_shard_files( "required according to the checkpoint index." ) - try: - # Load from URL - cached_folder = snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - allow_patterns=allow_patterns, - ignore_patterns=ignore_patterns, - user_agent=user_agent, - ) - if subfolder is not None: - cached_folder = os.path.join(cached_folder, subfolder) + try: + # Load from URL + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + user_agent=user_agent, + ) + if subfolder is not None: + cached_folder = os.path.join(cached_folder, subfolder) - # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so - # we don't have to catch them here. We have also dealt with EntryNotFoundError. - except HTTPError as e: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" - " again after checking your internet connection." - ) from e + # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so + # we don't have to catch them here. We have also dealt with EntryNotFoundError. + except HTTPError as e: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try" + " again after checking your internet connection." + ) from e # If `local_files_only=True`, `cached_folder` may not contain all the shard files. - if local_files_only: + elif local_files_only: _check_if_shards_exist_locally( local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames ) + if subfolder is not None: + cached_folder = os.path.join(cached_folder, subfolder) return cached_folder, sharded_metadata diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index a84968e613..1c688c9e9c 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1068,6 +1068,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def test_load_sharded_checkpoint_from_hub_local_subfolder(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") + loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True) + loaded_model = loaded_model.to(torch_device) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu def test_load_sharded_checkpoint_device_map_from_hub(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -1077,6 +1088,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + loaded_model = self.model_class.from_pretrained( + "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto" + ) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu def test_load_sharded_checkpoint_device_map_from_hub_local(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -1087,6 +1109,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) + @require_torch_gpu + def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder") + loaded_model = self.model_class.from_pretrained( + ckpt_path, local_files_only=True, subfolder="unet", device_map="auto" + ) + new_output = loaded_model(**inputs_dict) + + assert loaded_model + assert new_output.sample.shape == (4, 4, 16, 16) + @require_peft_backend def test_lora(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()