From 0feb21a18c44cfbf76a916afead986f04b339292 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Thu, 10 Nov 2022 00:09:22 +0100 Subject: [PATCH] [Tests] Fix mps+generator fast tests (#1230) * [Tests] Fix mps+generator fast tests * mps for Euler * retry * warmup issue again? * fix reproducible initial noise * Revert "fix reproducible initial noise" This reverts commit f300d05cb9782ed320064a0c58577a32d4139e6d. * fix reproducible initial noise * fix device --- .github/workflows/pr_tests.yml | 2 +- src/diffusers/pipelines/ddim/pipeline_ddim.py | 14 +++++----- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 14 +++++----- tests/pipelines/ddpm/test_ddpm.py | 8 ++++-- tests/test_scheduler.py | 26 +++++++++++++++---- 5 files changed, 44 insertions(+), 20 deletions(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index c978efe3b7..dc1c482aa0 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -136,7 +136,7 @@ jobs: - name: Run fast PyTorch tests on M1 (MPS) shell: arch -arch arm64 bash {0} run: | - ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/ + ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ - name: Failure short reports if: ${{ failure() }} diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index d0bca8038e..6db6298329 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -78,7 +78,7 @@ class DDIMPipeline(DiffusionPipeline): if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": message = ( f"The `generator` device is `{generator.device}` and does not match the pipeline " - f"device `{self.device}`, so the `generator` will be set to `None`. " + f"device `{self.device}`, so the `generator` will be ignored. " f'Please use `generator=torch.Generator(device="{self.device}")` instead.' ) deprecate( @@ -89,11 +89,13 @@ class DDIMPipeline(DiffusionPipeline): generator = None # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - device=self.device, - ) + image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = torch.randn(image_shape, generator=generator) + image = image.to(self.device) + else: + image = torch.randn(image_shape, generator=generator, device=self.device) # set step values self.scheduler.set_timesteps(num_inference_steps) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index d145c5d518..b7194664f4 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -83,7 +83,7 @@ class DDPMPipeline(DiffusionPipeline): if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": message = ( f"The `generator` device is `{generator.device}` and does not match the pipeline " - f"device `{self.device}`, so the `generator` will be set to `None`. " + f"device `{self.device}`, so the `generator` will be ignored. " f'Please use `torch.Generator(device="{self.device}")` instead.' ) deprecate( @@ -94,11 +94,13 @@ class DDPMPipeline(DiffusionPipeline): generator = None # Sample gaussian noise to begin loop - image = torch.randn( - (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), - generator=generator, - device=self.device, - ) + image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = torch.randn(image_shape, generator=generator) + image = image.to(self.device) + else: + image = torch.randn(image_shape, generator=generator, device=self.device) # set step values self.scheduler.set_timesteps(num_inference_steps) diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index e16e0d6e8c..14bc094697 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -81,10 +81,14 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): if torch_device == "mps": _ = ddpm(num_inference_steps=1) - generator = torch.Generator(device=torch_device).manual_seed(0) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images - generator = torch.Generator(device=torch_device).manual_seed(0) + generator = generator.manual_seed(0) image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0] image_slice = image[0, -3:, -3:, -1] diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index ab52171511..a9770f0a54 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1281,7 +1281,11 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): scheduler.set_timesteps(self.num_inference_steps) - generator = torch.Generator(torch_device).manual_seed(0) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) model = self.dummy_model() sample = self.dummy_sample_deter * scheduler.init_noise_sigma @@ -1308,7 +1312,11 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): scheduler.set_timesteps(self.num_inference_steps, device=torch_device) - generator = torch.Generator(torch_device).manual_seed(0) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) model = self.dummy_model() sample = self.dummy_sample_deter * scheduler.init_noise_sigma @@ -1364,7 +1372,11 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): scheduler.set_timesteps(self.num_inference_steps) - generator = torch.Generator(device=torch_device).manual_seed(0) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) model = self.dummy_model() sample = self.dummy_sample_deter * scheduler.init_noise_sigma @@ -1381,7 +1393,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - if str(torch_device).startswith("cpu"): + if torch_device in ["cpu", "mps"]: assert abs(result_sum.item() - 152.3192) < 1e-2 assert abs(result_mean.item() - 0.1983) < 1e-3 else: @@ -1396,7 +1408,11 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): scheduler.set_timesteps(self.num_inference_steps, device=torch_device) - generator = torch.Generator(device=torch_device).manual_seed(0) + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) model = self.dummy_model() sample = self.dummy_sample_deter * scheduler.init_noise_sigma