mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Chore] create a utility for calculating the expected number of shards. (#8692)
create a utility for calculating the expected number of shards.
This commit is contained in:
@@ -55,6 +55,15 @@ from diffusers.utils.testing_utils import (
|
||||
from ..others.test_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
def caculate_expected_num_shards(index_map_path):
|
||||
with open(index_map_path) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
return expected_num_shards
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
@@ -888,12 +897,7 @@ class ModelTesterMixin:
|
||||
# Now check if the right number of shards exists. First, let's get the number of shards.
|
||||
# Since this number can be dependent on the model being tested, it's important that we calculate it
|
||||
# instead of hardcoding it.
|
||||
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
|
||||
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
self.assertTrue(actual_num_shards == expected_num_shards)
|
||||
|
||||
@@ -924,12 +928,7 @@ class ModelTesterMixin:
|
||||
# Now check if the right number of shards exists. First, let's get the number of shards.
|
||||
# Since this number can be dependent on the model being tested, it's important that we calculate it
|
||||
# instead of hardcoding it.
|
||||
with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
|
||||
expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
|
||||
self.assertTrue(actual_num_shards == expected_num_shards)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user