1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix ldm tests on master by not running the CPU tests on GPU (#1729)

This commit is contained in:
Patrick von Platen
2022-12-16 15:28:40 +01:00
committed by GitHub
parent a40095dd22
commit c6d0dff4a3

View File

@@ -72,6 +72,9 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
return CLIPTextModel(config)
def test_inference_text2img(self):
if torch_device != "cpu":
return
unet = self.dummy_cond_unet
scheduler = DDIMScheduler()
vae = self.dummy_vae
@@ -91,12 +94,16 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
).images
generator = torch.manual_seed(0)
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image = ldm(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
).images
generator = torch.manual_seed(0)
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = ldm(
[prompt],
generator=generator,
@@ -124,7 +131,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image = ldm(
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
).images
@@ -141,7 +151,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device=device).manual_seed(0)
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]