1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Chore] create a utility for calculating the expected number of shards. (#8692)

create a utility for calculating the expected number of shards.
This commit is contained in:
Sayak Paul
2024-06-25 17:05:39 +05:30
committed by GitHub
parent 1f81fbe274
commit 4ad7a1f5fd

View File

@@ -55,6 +55,15 @@ from diffusers.utils.testing_utils import (
from ..others.test_utils import TOKEN, USER, is_staging_test
def caculate_expected_num_shards(index_map_path):
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
return expected_num_shards
# Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
error = None
@@ -888,12 +897,7 @@ class ModelTesterMixin:
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)
@@ -924,12 +928,7 @@ class ModelTesterMixin:
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)