1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2025-09-23 14:05:51 +05:30
parent 490c4761b4
commit 01aa188d8d

View File

@@ -1947,6 +1947,12 @@ class VAETestMixin:
usually don't do slicing and tiling.
"""
@staticmethod
def _accepts_generator(model):
model_sig = inspect.signature(model.forward)
accepts_generator = "generator" in model_sig.parameters
return accepts_generator
def test_enable_disable_tiling(self):
if not hasattr(self.model_class, "enable_tiling"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
@@ -1957,14 +1963,19 @@ class VAETestMixin:
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator")
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling = model(**inputs_dict)[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_tiling = model(**inputs_dict)[0]
assert (
output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
@@ -1972,7 +1983,9 @@ class VAETestMixin:
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling_2 = model(**inputs_dict)[0]
assert np.allclose(
output_without_tiling.detach().cpu().numpy().all(),
@@ -1990,13 +2003,19 @@ class VAETestMixin:
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
output_without_slicing = model(**inputs_dict)[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_slicing = model(**inputs_dict)[0]
assert (
output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
@@ -2004,7 +2023,9 @@ class VAETestMixin:
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_slicing_2 = model(**inputs_dict)[0]
assert np.allclose(
output_without_slicing.detach().cpu().numpy().all(),