From 6a01c4681cf80fded080cb01a8754cac025894f5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 24 Sep 2025 09:30:34 +0530 Subject: [PATCH] u[ --- tests/models/test_modeling_common.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 080158de41..aac6474c48 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase): class UNetTesterMixin: + @staticmethod + def _accepts_norm_num_groups(model_class): + model_sig = inspect.signature(model_class.__init__) + accepts_norm_groups = "norm_num_groups" in model_sig.parameters + return accepts_norm_groups + def test_forward_with_norm_groups(self): + if not self._accepts_norm_num_groups(self.model_class): + pytest.skip(f"Test not supported for {self.model_class.__name__}") init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict["norm_num_groups"] = 16