mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix ldm tests on master by not running the CPU tests on GPU (#1729)
This commit is contained in:
committed by
GitHub
parent
a40095dd22
commit
c6d0dff4a3
@@ -72,6 +72,9 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
|
||||
return CLIPTextModel(config)
|
||||
|
||||
def test_inference_text2img(self):
|
||||
if torch_device != "cpu":
|
||||
return
|
||||
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
vae = self.dummy_vae
|
||||
@@ -91,12 +94,16 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
|
||||
).images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
|
||||
image = ldm(
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
|
||||
).images
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
|
||||
image_from_tuple = ldm(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
@@ -124,7 +131,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
|
||||
image = ldm(
|
||||
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
|
||||
).images
|
||||
@@ -141,7 +151,10 @@ class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
device = torch_device if torch_device != "mps" else "cpu"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
|
||||
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
Reference in New Issue
Block a user