From 01aa188d8d908251eaf5a2897eefae63ad4ea3f9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 23 Sep 2025 14:05:51 +0530 Subject: [PATCH] up --- tests/models/test_modeling_common.py | 35 ++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 3964dfa7ad..5d02747986 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1947,6 +1947,12 @@ class VAETestMixin: usually don't do slicing and tiling. """ + @staticmethod + def _accepts_generator(model): + model_sig = inspect.signature(model.forward) + accepts_generator = "generator" in model_sig.parameters + return accepts_generator + def test_enable_disable_tiling(self): if not hasattr(self.model_class, "enable_tiling"): pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.") @@ -1957,14 +1963,19 @@ class VAETestMixin: model = self.model_class(**init_dict).to(torch_device) inputs_dict.update({"return_dict": False}) - _ = inputs_dict.pop("generator") + _ = inputs_dict.pop("generator", None) + accepts_generator = self._accepts_generator(model) torch.manual_seed(0) - output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_tiling = model(**inputs_dict)[0] torch.manual_seed(0) model.enable_tiling() - output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_with_tiling = model(**inputs_dict)[0] assert ( output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy() @@ -1972,7 +1983,9 @@ class VAETestMixin: torch.manual_seed(0) model.disable_tiling() - output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_tiling_2 = model(**inputs_dict)[0] assert np.allclose( output_without_tiling.detach().cpu().numpy().all(), @@ -1990,13 +2003,19 @@ class VAETestMixin: inputs_dict.update({"return_dict": False}) _ = inputs_dict.pop("generator", None) + accepts_generator = self._accepts_generator(model) + + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) torch.manual_seed(0) - output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + output_without_slicing = model(**inputs_dict)[0] torch.manual_seed(0) model.enable_slicing() - output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_with_slicing = model(**inputs_dict)[0] assert ( output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy() @@ -2004,7 +2023,9 @@ class VAETestMixin: torch.manual_seed(0) model.disable_slicing() - output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + output_without_slicing_2 = model(**inputs_dict)[0] assert np.allclose( output_without_slicing.detach().cpu().numpy().all(),