From c6d0dff4a39137ff206af76b655f7bcf3cadaf32 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 16 Dec 2022 15:28:40 +0100 Subject: [PATCH] Fix ldm tests on master by not running the CPU tests on GPU (#1729) --- .../latent_diffusion/test_latent_diffusion.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py index ad8d778072..4b7c89977d 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py @@ -72,6 +72,9 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase): return CLIPTextModel(config) def test_inference_text2img(self): + if torch_device != "cpu": + return + unet = self.dummy_cond_unet scheduler = DDIMScheduler() vae = self.dummy_vae @@ -91,12 +94,16 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase): [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy" ).images - generator = torch.manual_seed(0) + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device=device).manual_seed(0) + image = ldm( [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy" ).images - generator = torch.manual_seed(0) + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = ldm( [prompt], generator=generator, @@ -124,7 +131,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase): ldm.set_progress_bar_config(disable=None) prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) + + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device=device).manual_seed(0) + image = ldm( [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy" ).images @@ -141,7 +151,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase): ldm.set_progress_bar_config(disable=None) prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) + + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device=device).manual_seed(0) + image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images image_slice = image[0, -3:, -3:, -1]