mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
a few fix for shard checkpoints (#8656)
fix Co-authored-by: yiyixuxu <yixu310@gmail,com>
This commit is contained in:
@@ -819,7 +819,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
force_hook=force_hook,
|
||||
force_hooks=force_hook,
|
||||
strict=True,
|
||||
)
|
||||
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
||||
|
||||
@@ -898,6 +898,7 @@ class ModelTesterMixin:
|
||||
self.assertTrue(actual_num_shards == expected_num_shards)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir)
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
@@ -933,6 +934,7 @@ class ModelTesterMixin:
|
||||
self.assertTrue(actual_num_shards == expected_num_shards)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
@@ -1039,6 +1039,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
def test_load_sharded_checkpoint_from_hub(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
@@ -1049,6 +1050,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
|
||||
Reference in New Issue
Block a user