mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
u[
This commit is contained in:
@@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
|
||||
|
||||
class UNetTesterMixin:
|
||||
@staticmethod
|
||||
def _accepts_norm_num_groups(model_class):
|
||||
model_sig = inspect.signature(model_class.__init__)
|
||||
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
|
||||
return accepts_norm_groups
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
if not self._accepts_norm_num_groups(self.model_class):
|
||||
pytest.skip(f"Test not supported for {self.model_class.__name__}")
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
|
||||
Reference in New Issue
Block a user