diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index ff52ee701d..e1752febcc 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1936,6 +1936,9 @@ class TorchCompileTesterMixin: _ = model(**inputs_dict) def test_torch_compile_repeated_blocks(self): + if self.model_class._repeated_blocks is None: + pytest.skip("Skipping test as `_repeated_blocks` is not set in the model class.") + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device)