diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py index ad8d778072..4b7c89977d 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py @@ -72,6 +72,9 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase): return CLIPTextModel(config) def test_inference_text2img(self): + if torch_device != "cpu": + return + unet = self.dummy_cond_unet scheduler = DDIMScheduler() vae = self.dummy_vae @@ -91,12 +94,16 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase): [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy" ).images - generator = torch.manual_seed(0) + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device=device).manual_seed(0) + image = ldm( [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy" ).images - generator = torch.manual_seed(0) + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = ldm( [prompt], generator=generator, @@ -124,7 +131,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase): ldm.set_progress_bar_config(disable=None) prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) + + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device=device).manual_seed(0) + image = ldm( [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy" ).images @@ -141,7 +151,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase): ldm.set_progress_bar_config(disable=None) prompt = "A painting of a squirrel eating a burger" - generator = torch.manual_seed(0) + + device = torch_device if torch_device != "mps" else "cpu" + generator = torch.Generator(device=device).manual_seed(0) + image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images image_slice = image[0, -3:, -3:, -1]