From 86d4c5a25428c801fa720f7150a8bafeb6ff98e2 Mon Sep 17 00:00:00 2001 From: anton-l Date: Wed, 9 Nov 2022 23:47:49 +0100 Subject: [PATCH] warmup issue again? --- .github/workflows/pr_tests.yml | 2 +- src/diffusers/pipelines/ddim/pipeline_ddim.py | 2 +- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +- tests/pipelines/ddpm/test_ddpm.py | 8 +++----- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index c978efe3b7..dc1c482aa0 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -136,7 +136,7 @@ jobs: - name: Run fast PyTorch tests on M1 (MPS) shell: arch -arch arm64 bash {0} run: | - ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/ + ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ - name: Failure short reports if: ${{ failure() }} diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index d0bca8038e..79ab9e2dc8 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -78,7 +78,7 @@ class DDIMPipeline(DiffusionPipeline): if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": message = ( f"The `generator` device is `{generator.device}` and does not match the pipeline " - f"device `{self.device}`, so the `generator` will be set to `None`. " + f"device `{self.device}`, so the `generator` will be ignored. " f'Please use `generator=torch.Generator(device="{self.device}")` instead.' ) deprecate( diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index d145c5d518..04b7e65f48 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -83,7 +83,7 @@ class DDPMPipeline(DiffusionPipeline): if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": message = ( f"The `generator` device is `{generator.device}` and does not match the pipeline " - f"device `{self.device}`, so the `generator` will be set to `None`. " + f"device `{self.device}`, so the `generator` will be ignored. " f'Please use `torch.Generator(device="{self.device}")` instead.' ) deprecate( diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index f4517907cd..4d59d08c93 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -80,7 +80,6 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # Warmup pass when using mps (see #372) if torch_device == "mps": _ = ddpm(num_inference_steps=1) - if torch_device == "mps": # device type MPS is not supported for torch.Generator() api. generator = torch.manual_seed(0) @@ -88,11 +87,10 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): generator = torch.Generator(device=torch_device).manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + # Warmup pass when using mps (see #372) if torch_device == "mps": - # device type MPS is not supported for torch.Generator() api. - generator = torch.manual_seed(0) - else: - generator = torch.Generator(device=torch_device).manual_seed(0) + _ = ddpm(num_inference_steps=1) + generator = generator.manual_seed(0) image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0] image_slice = image[0, -3:, -3:, -1]