1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

a few fix for shard checkpoints (#8656)

fix

Co-authored-by: yiyixuxu <yixu310@gmail,com>
This commit is contained in:
YiYi Xu
2024-06-20 21:20:58 -10:00
committed by GitHub
parent adc31940a9
commit c71c19c5e6
3 changed files with 5 additions and 1 deletions

View File

@@ -819,7 +819,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
force_hook=force_hook,
force_hooks=force_hook,
strict=True,
)
model._undo_temp_convert_self_to_deprecated_attention_blocks()

View File

@@ -898,6 +898,7 @@ class ModelTesterMixin:
self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir)
new_model = new_model.to(torch_device)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
@@ -933,6 +934,7 @@ class ModelTesterMixin:
self.assertTrue(actual_num_shards == expected_num_shards)
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
new_model = new_model.to(torch_device)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)

View File

@@ -1039,6 +1039,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
def test_load_sharded_checkpoint_from_hub(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
@@ -1049,6 +1050,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model