mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -1947,6 +1947,12 @@ class VAETestMixin:
|
||||
usually don't do slicing and tiling.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _accepts_generator(model):
|
||||
model_sig = inspect.signature(model.forward)
|
||||
accepts_generator = "generator" in model_sig.parameters
|
||||
return accepts_generator
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
if not hasattr(self.model_class, "enable_tiling"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
@@ -1957,14 +1963,19 @@ class VAETestMixin:
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator")
|
||||
_ = inputs_dict.pop("generator", None)
|
||||
accepts_generator = self._accepts_generator(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_tiling = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_tiling()
|
||||
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_with_tiling = model(**inputs_dict)[0]
|
||||
|
||||
assert (
|
||||
output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
|
||||
@@ -1972,7 +1983,9 @@ class VAETestMixin:
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_tiling()
|
||||
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_tiling_2 = model(**inputs_dict)[0]
|
||||
|
||||
assert np.allclose(
|
||||
output_without_tiling.detach().cpu().numpy().all(),
|
||||
@@ -1990,13 +2003,19 @@ class VAETestMixin:
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator", None)
|
||||
accepts_generator = self._accepts_generator(model)
|
||||
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
output_without_slicing = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_with_slicing = model(**inputs_dict)[0]
|
||||
|
||||
assert (
|
||||
output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
|
||||
@@ -2004,7 +2023,9 @@ class VAETestMixin:
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_slicing_2 = model(**inputs_dict)[0]
|
||||
|
||||
assert np.allclose(
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
|
||||
Reference in New Issue
Block a user