1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Tests] Make sure tests are on GPU (#269)

* [Tests] Make sure tests are on GPU

* move more models

* speed up tests
This commit is contained in:
Patrick von Platen
2022-08-29 15:58:11 +02:00
committed by GitHub
parent 16172c1c7e
commit 9e1b1ca49d
5 changed files with 47 additions and 14 deletions

View File

@@ -24,6 +24,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class UnetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -133,18 +136,20 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
def test_output_pretrained(self):
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
model.eval()
model.to(torch_device)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
noise = noise.to(torch_device)
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output_slice = output[0, -1, -3:, -3:].flatten()
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
# fmt: on

View File

@@ -23,6 +23,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
@@ -74,6 +77,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
def test_output_pretrained(self):
model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
model = model.to(torch_device)
model.eval()
torch.manual_seed(0)
@@ -81,10 +85,11 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
torch.cuda.manual_seed_all(0)
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image, sample_posterior=True)
output_slice = output[0, -1, -3:, -3:].flatten()
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
# fmt: on

View File

@@ -23,6 +23,9 @@ from diffusers.testing_utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class VQModelTests(ModelTesterMixin, unittest.TestCase):
model_class = VQModel
@@ -73,17 +76,18 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
def test_output_pretrained(self):
model = VQModel.from_pretrained("fusing/vqgan-dummy")
model.eval()
model.to(torch_device).eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image)
output_slice = output[0, -1, -3:, -3:].flatten()
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
# fmt: on

View File

@@ -59,10 +59,12 @@ class PipelineTesterMixin(unittest.TestCase):
schedular = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, schedular)
ddpm.to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
new_ddpm.to(torch_device)
generator = torch.manual_seed(0)
@@ -76,11 +78,12 @@ class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_hub(self):
model_path = "google/ddpm-cifar10-32"
ddpm = DDPMPipeline.from_pretrained(model_path)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm.scheduler.num_timesteps = 10
ddpm_from_hub.scheduler.num_timesteps = 10
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm.to(torch_device)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm_from_hub.to(torch_device)
generator = torch.manual_seed(0)
@@ -94,14 +97,15 @@ class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_hub_pass_model(self):
model_path = "google/ddpm-cifar10-32"
scheduler = DDPMScheduler(num_train_timesteps=10)
# pass unet into DiffusionPipeline
unet = UNet2DModel.from_pretrained(model_path)
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet)
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
ddpm_from_hub_custom_model.to(torch_device)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
ddpm_from_hub_custom_model.scheduler.num_timesteps = 10
ddpm_from_hub.scheduler.num_timesteps = 10
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
ddpm_from_hub.to(torch_device)
generator = torch.manual_seed(0)
@@ -116,6 +120,7 @@ class PipelineTesterMixin(unittest.TestCase):
model_path = "google/ddpm-cifar10-32"
pipe = DDIMPipeline.from_pretrained(model_path)
pipe.to(torch_device)
generator = torch.manual_seed(0)
images = pipe(generator=generator, output_type="numpy")["sample"]
@@ -141,6 +146,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler = scheduler.set_format("pt")
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -159,6 +165,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler = DDIMScheduler.from_config(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -177,6 +184,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
ddim.to(torch_device)
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
@@ -195,6 +203,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
generator = torch.manual_seed(0)
image = pndm(generator=generator, output_type="numpy")["sample"]
@@ -207,6 +216,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_text2img(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device)
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@@ -223,6 +233,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_text2img_fast(self):
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
ldm.to(torch_device)
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
@@ -290,6 +301,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler = ScoreSdeVeScheduler.from_config(model_id)
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
sde_ve.to(torch_device)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
@@ -304,6 +316,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_ldm_uncond(self):
ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
ldm.to(torch_device)
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
@@ -323,7 +336,9 @@ class PipelineTesterMixin(unittest.TestCase):
ddim_scheduler = DDIMScheduler(tensor_format="pt")
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
ddim.to(torch_device)
generator = torch.manual_seed(0)
ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -343,7 +358,10 @@ class PipelineTesterMixin(unittest.TestCase):
ddim_scheduler = DDIMScheduler(tensor_format="pt")
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
ddim.to(torch_device)
generator = torch.manual_seed(0)
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
@@ -363,6 +381,7 @@ class PipelineTesterMixin(unittest.TestCase):
scheduler = KarrasVeScheduler(tensor_format="pt")
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
pipe.to(torch_device)
generator = torch.manual_seed(0)
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]