From e05f03ae41540123a99afdacff86d26170c1315d Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Thu, 28 Jul 2022 09:29:29 +0200 Subject: [PATCH] Disable test_ddpm_ddim_equality_batched until resolved (#142) disable test_ddpm_ddim_equality_batched --- tests/test_modeling_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index c47a787c48..ed98a9e536 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -894,10 +894,10 @@ class PipelineTesterMixin(unittest.TestCase): generator = torch.manual_seed(0) ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"] - # the values aren't exactly equal, but the images look the same upon visual inspection + # the values aren't exactly equal, but the images look the same visually assert np.abs(ddpm_image - ddim_image).max() < 1e-1 - @slow + @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation") def test_ddpm_ddim_equality_batched(self): model_id = "google/ddpm-cifar10-32" @@ -909,12 +909,12 @@ class PipelineTesterMixin(unittest.TestCase): ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler) generator = torch.manual_seed(0) - ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy")["sample"] + ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"] generator = torch.manual_seed(0) - ddim_images = ddim(batch_size=2, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[ + ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[ "sample" ] - # the values aren't exactly equal, but the images look the same upon visual inspection + # the values aren't exactly equal, but the images look the same visually assert np.abs(ddpm_images - ddim_images).max() < 1e-1