diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 080158de41..aac6474c48 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase): class UNetTesterMixin: + @staticmethod + def _accepts_norm_num_groups(model_class): + model_sig = inspect.signature(model_class.__init__) + accepts_norm_groups = "norm_num_groups" in model_sig.parameters + return accepts_norm_groups + def test_forward_with_norm_groups(self): + if not self._accepts_norm_num_groups(self.model_class): + pytest.skip(f"Test not supported for {self.model_class.__name__}") init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 16