diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index bebc0febe6..f1e977f0db 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -109,7 +109,7 @@ def check_if_lora_correctly_set(model) -> bool: def normalize_output(out): - out0 = out[0] + out0 = out[0] if isinstance(out, tuple) else out return torch.stack(out0) if isinstance(out0, list) else out0 @@ -541,10 +541,8 @@ class ModelTesterMixin: if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] - if isinstance(image, list): - image = torch.stack(image) - if isinstance(new_image, list): - new_image = torch.stack(new_image) + image = normalize_output(image) + new_image = normalize_output(new_image) max_diff = (image - new_image).abs().max().item() self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") @@ -790,10 +788,8 @@ class ModelTesterMixin: if isinstance(new_image, dict): new_image = new_image.to_tuple()[0] - if isinstance(image, list): - image = torch.stack(image) - if isinstance(new_image, list): - new_image = torch.stack(new_image) + image = normalize_output(image) + new_image = normalize_output(new_image) max_diff = (image - new_image).abs().max().item() self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") @@ -857,10 +853,8 @@ class ModelTesterMixin: if isinstance(second, dict): second = second.to_tuple()[0] - if isinstance(first, list): - first = torch.stack(first) - if isinstance(second, list): - second = torch.stack(second) + first = normalize_output(first) + second = normalize_output(second) out_1 = first.cpu().numpy() out_2 = second.cpu().numpy() @@ -1349,15 +1343,16 @@ class ModelTesterMixin: max_memory = {0: max_size, "cpu": model_size * 2} new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) - # Offload check + # Making sure part of the model will actually end up offloaded self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) + self.check_device_map_is_respected(new_model, new_model.hf_device_map) torch.manual_seed(0) new_output = new_model(**inputs_dict) - new_normalized_out = normalize_output(new_output) + new_normalized_output = normalize_output(new_output) - self.assertTrue(torch.allclose(base_normalized_output, new_normalized_out, atol=1e-5)) + self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5)) @require_torch_accelerator def test_disk_offload_without_safetensors(self): @@ -1680,8 +1675,7 @@ class ModelTesterMixin: model.eval() model.to(torch_device) base_slice = model(**inputs_dict)[0] - if isinstance(base_slice, list): - base_slice = torch.stack(base_slice) + base_slice = normalize_output(base_slice) base_slice = base_slice.detach().flatten().cpu().numpy() def check_linear_dtype(module, storage_dtype, compute_dtype): @@ -1709,8 +1703,7 @@ class ModelTesterMixin: check_linear_dtype(model, storage_dtype, compute_dtype) output = model(**inputs_dict)[0] - if isinstance(output, list): - output = torch.stack(output) + output = normalize_output(output) output = output.float().flatten().detach().cpu().numpy() # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. @@ -1800,29 +1793,25 @@ class ModelTesterMixin: model.to(torch_device) output_without_group_offloading = run_forward(model) - if isinstance(output_without_group_offloading, list): - output_without_group_offloading = torch.stack(output_without_group_offloading) + output_without_group_offloading = normalize_output(output_without_group_offloading) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) output_with_group_offloading1 = run_forward(model) - if isinstance(output_with_group_offloading1, list): - output_with_group_offloading1 = torch.stack(output_with_group_offloading1) + output_with_group_offloading1 = normalize_output(output_with_group_offloading1) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) output_with_group_offloading2 = run_forward(model) - if isinstance(output_with_group_offloading2, list): - output_with_group_offloading2 = torch.stack(output_with_group_offloading2) + output_with_group_offloading2 = normalize_output(output_with_group_offloading2) torch.manual_seed(0) model = self.model_class(**init_dict) model.enable_group_offload(torch_device, offload_type="leaf_level") output_with_group_offloading3 = run_forward(model) - if isinstance(output_with_group_offloading3, list): - output_with_group_offloading3 = torch.stack(output_with_group_offloading3) + output_with_group_offloading3 = normalize_output(output_with_group_offloading3) torch.manual_seed(0) model = self.model_class(**init_dict) @@ -1830,8 +1819,7 @@ class ModelTesterMixin: torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream ) output_with_group_offloading4 = run_forward(model) - if isinstance(output_with_group_offloading4, list): - output_with_group_offloading4 = torch.stack(output_with_group_offloading4) + output_with_group_offloading4 = normalize_output(output_with_group_offloading4) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) @@ -1902,8 +1890,7 @@ class ModelTesterMixin: model.eval() model.to(torch_device) output_without_group_offloading = _run_forward(model, inputs_dict) - if isinstance(output_without_group_offloading, list): - output_without_group_offloading = torch.stack(output_without_group_offloading) + output_without_group_offloading = normalize_output(output_without_group_offloading) torch.manual_seed(0) model = self.model_class(**init_dict) @@ -1939,8 +1926,7 @@ class ModelTesterMixin: raise ValueError(f"Following files are missing: {', '.join(missing_files)}") output_with_group_offloading = _run_forward(model, inputs_dict) - if isinstance(output_with_group_offloading, list): - output_with_group_offloading = torch.stack(output_with_group_offloading) + output_with_group_offloading = normalize_output(output_with_group_offloading) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol)) def test_auto_model(self, expected_max_diff=5e-5):