From c71c19c5e692715dc7a75771936c40201eac9409 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 20 Jun 2024 21:20:58 -1000 Subject: [PATCH] a few fix for shard checkpoints (#8656) fix Co-authored-by: yiyixuxu --- src/diffusers/models/modeling_utils.py | 2 +- tests/models/test_modeling_common.py | 2 ++ tests/models/unets/test_models_unet_2d_condition.py | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index ab98d4cea5..d4851ab403 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -819,7 +819,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, - force_hook=force_hook, + force_hooks=force_hook, strict=True, ) model._undo_temp_convert_self_to_deprecated_attention_blocks() diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5a1901d49a..1b131c74af 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -898,6 +898,7 @@ class ModelTesterMixin: self.assertTrue(actual_num_shards == expected_num_shards) new_model = self.model_class.from_pretrained(tmp_dir) + new_model = new_model.to(torch_device) torch.manual_seed(0) new_output = new_model(**inputs_dict) @@ -933,6 +934,7 @@ class ModelTesterMixin: self.assertTrue(actual_num_shards == expected_num_shards) new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto") + new_model = new_model.to(torch_device) torch.manual_seed(0) new_output = new_model(**inputs_dict) diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index dd524c5b86..63e66dabf0 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1039,6 +1039,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test def test_load_sharded_checkpoint_from_hub(self): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy") + loaded_model = loaded_model.to(torch_device) new_output = loaded_model(**inputs_dict) assert loaded_model @@ -1049,6 +1050,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test _, inputs_dict = self.prepare_init_args_and_inputs_for_common() ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy") loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True) + loaded_model = loaded_model.to(torch_device) new_output = loaded_model(**inputs_dict) assert loaded_model