diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 79ab9e2dc8..c68e824089 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -89,11 +89,13 @@ class DDIMPipeline(DiffusionPipeline):
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(self.device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index 04b7e65f48..f28f4406e7 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -94,11 +94,13 @@ class DDPMPipeline(DiffusionPipeline):
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(self.device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index 4d59d08c93..e335c57078 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -20,7 +20,7 @@
 import torch
 
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import require_torch, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -44,18 +44,21 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model
 
     def test_inference(self):
-        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()
 
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(device)
+        ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
@@ -65,8 +68,9 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
 
     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -80,6 +84,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # Warmup pass when using mps (see #372)
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
+
         if torch_device == "mps":
             # device type MPS is not supported for torch.Generator() api.
             generator = torch.manual_seed(0)
@@ -87,9 +92,6 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
 
         generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
@@ -102,7 +104,7 @@
 
 
 @slow
-@require_torch_gpu
+@require_torch
class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -114,11 +116,11 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
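Taken together, the two pipeline hunks apply one pattern: when the pipeline runs on Apple's `mps` backend, the initial Gaussian latent is drawn on the CPU, where the seeded generator behaves reproducibly, and only then moved to the device. Below is a minimal standalone sketch of that pattern, not part of this diff; the helper name `sample_initial_noise` is hypothetical.

```python
import torch


def sample_initial_noise(shape, generator=None, device=torch.device("cpu")):
    # torch.randn does not produce reproducible results when generating
    # directly on "mps", so draw on the CPU with the seeded generator and
    # move the result to the device afterwards.
    if device.type == "mps":
        noise = torch.randn(shape, generator=generator)  # CPU draw
        return noise.to(device)
    # On other backends, generate directly on the target device (the
    # generator must live on a device compatible with `device`).
    return torch.randn(shape, generator=generator, device=device)


# With a fixed seed, the same latent comes back whether the target device
# is "cpu" or "mps", which is what the updated tests rely on.
generator = torch.manual_seed(0)
noise = sample_initial_noise((1, 3, 32, 32), generator=generator)
```

This is also why the tests switch from `torch.Generator(device=torch_device)` to a plain `torch.manual_seed(0)`: the generator now always lives on the CPU, so one seed yields the same noise on every backend.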