diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 1b131c74af..e1fcae6266 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -55,6 +55,15 @@ from diffusers.utils.testing_utils import ( from ..others.test_utils import TOKEN, USER, is_staging_test +def caculate_expected_num_shards(index_map_path): + with open(index_map_path) as f: + weight_map_dict = json.load(f)["weight_map"] + first_key = list(weight_map_dict.keys())[0] + weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors + expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) + return expected_num_shards + + # Will be run via run_test_in_subprocess def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): error = None @@ -888,12 +897,7 @@ class ModelTesterMixin: # Now check if the right number of shards exists. First, let's get the number of shards. # Since this number can be dependent on the model being tested, it's important that we calculate it # instead of hardcoding it. - with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f: - weight_map_dict = json.load(f)["weight_map"] - first_key = list(weight_map_dict.keys())[0] - weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors - expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) - + expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) self.assertTrue(actual_num_shards == expected_num_shards) @@ -924,12 +928,7 @@ class ModelTesterMixin: # Now check if the right number of shards exists. First, let's get the number of shards. # Since this number can be dependent on the model being tested, it's important that we calculate it # instead of hardcoding it. - with open(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) as f: - weight_map_dict = json.load(f)["weight_map"] - first_key = list(weight_map_dict.keys())[0] - weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors - expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) - + expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)) actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")]) self.assertTrue(actual_num_shards == expected_num_shards)