
[Core] fix: shard loading and saving when variant is provided. (#8869)

fix: shard loading and saving when variant is provided.
Sayak Paul
2024-07-17 08:26:28 +05:30
committed by GitHub
parent f6cfe0a1e5
commit 0f09b01ab3
2 changed files with 40 additions and 1 deletion
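For context, the user-facing flow this commit repairs is saving a sharded checkpoint together with a variant and loading it back. A minimal sketch of that flow (the repo id, model class, and shard size are illustrative; the save_pretrained/from_pretrained calls mirror the test added below):

from diffusers import UNet2DConditionModel

# Illustrative: any diffusers ModelMixin subclass behaves the same way.
model = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16"
)

# Saving with both a shard-size limit and a variant writes shard files plus a
# variant-aware index such as diffusion_pytorch_model.safetensors.fp16.index.json.
model.save_pretrained("sharded-unet", max_shard_size="5GB", variant="fp16")

# Before this fix, the variant was spliced into the shard-index filename at the
# wrong position, so sharded variant checkpoints could not be round-tripped.
reloaded = UNet2DConditionModel.from_pretrained("sharded-unet", variant="fp16")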


@@ -271,7 +271,8 @@ if cache_version < 1:
 def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     if variant is not None:
         splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
+        split_index = -2 if weights_name.endswith(".index.json") else -1
+        splits = splits[:-split_index] + [variant] + splits[-split_index:]
         weights_name = ".".join(splits)
 
     return weights_name
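Concretely, the new split_index keeps the variant from being spliced into the ".index.json" suffix of a shard index. A short worked example (the filenames follow diffusers' default weight names; the snippet just replays the patched logic):

from typing import Optional

def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    # Same logic as the patched helper above.
    if variant is not None:
        splits = weights_name.split(".")
        split_index = -2 if weights_name.endswith(".index.json") else -1
        splits = splits[:-split_index] + [variant] + splits[-split_index:]
        weights_name = ".".join(splits)
    return weights_name

# Unchanged behaviour for a plain weights file:
assert _add_variant("diffusion_pytorch_model.safetensors", "fp16") == "diffusion_pytorch_model.fp16.safetensors"

# Fixed behaviour for a sharded-checkpoint index file; the old code produced
# "diffusion_pytorch_model.safetensors.index.fp16.json", which did not match
# the name the loader expects for a variant shard index.
assert (
    _add_variant("diffusion_pytorch_model.safetensors.index.json", "fp16")
    == "diffusion_pytorch_model.safetensors.fp16.index.json"
)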


@@ -40,6 +40,7 @@ from diffusers.models.attention_processor import (
 )
 from diffusers.training_utils import EMAModel
 from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, is_torch_npu_available, is_xformers_available, logging
+from diffusers.utils.hub_utils import _add_variant
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     get_python_version,
@@ -915,6 +916,43 @@ class ModelTesterMixin:
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_torch_gpu
+    def test_sharded_checkpoints_with_variant(self):
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
+        variant = "fp16"
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # It doesn't matter if the actual model is in fp16 or not. Just adding the variant and
+            # testing if loading works with the variant when the checkpoint is sharded should be
+            # enough.
+            model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
+            index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+            self.assertTrue(os.path.exists(os.path.join(tmp_dir, index_filename)))
+
+            # Now check if the right number of shards exists. First, let's get the number of shards.
+            # Since this number can be dependent on the model being tested, it's important that we calculate it
+            # instead of hardcoding it.
+            expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
+            actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+            self.assertTrue(actual_num_shards == expected_num_shards)
+
+            new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
+            new_model = new_model.to(torch_device)
+
+            torch.manual_seed(0)
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            new_output = new_model(**inputs_dict)
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
     @require_torch_gpu
     def test_sharded_checkpoints_device_map(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
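The caculate_expected_num_shards helper used by the new test is defined elsewhere in the test module and is not part of this diff. A plausible sketch, assuming it derives the shard count from the saved index's weight_map (the real helper may differ):

import json

def caculate_expected_num_shards(index_map_path):
    # Hypothetical reconstruction: the index file maps each parameter name to
    # the shard file that stores it, so the number of distinct shard files
    # referenced there is the number of shards save_pretrained wrote.
    with open(index_map_path) as f:
        weight_map = json.load(f)["weight_map"]
    return len(set(weight_map.values()))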