From 902af3ddea09b3357a3f8ccfeda0657ee52f83b1 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 22 May 2023 16:04:30 +0200 Subject: [PATCH] Prepare a couple more compile tests to run in subprocess. --- tests/models/test_modeling_common.py | 46 ++++++++++++------ tests/pipelines/test_pipelines.py | 73 +++++++++++++++++----------- 2 files changed, 75 insertions(+), 44 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index b2c5f2d79d..98654296d7 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -22,12 +22,37 @@ from typing import Dict, List, Tuple import numpy as np import requests_mock import torch +import traceback from requests.exceptions import HTTPError from diffusers.models import UNet2DConditionModel from diffusers.training_utils import EMAModel from diffusers.utils import logging, torch_device -from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu +from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, run_test_in_subprocess + + +# Will be run via run_test_in_subprocess +def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): + error = None + try: + init_dict, model_class = in_queue.get(timeout=timeout) + + model = model_class(**init_dict) + model.to(torch_device) + model = torch.compile(model) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + assert new_model.__class__ == model_class + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() class ModelUtilsTest(unittest.TestCase): @@ -234,21 +259,12 @@ class ModelTesterMixin: max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") - - @require_torch_gpu + + @require_torch_2 def test_from_save_pretrained_dynamo(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model = torch.compile(model) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) - - assert new_model.__class__ == self.model_class + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + inputs = [init_dict, self.model_class] + run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=inputs) def test_from_save_pretrained_dtype(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index a9abb0b4fb..5851680043 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -28,6 +28,7 @@ import PIL import requests_mock import safetensors.torch import torch +import traceback from parameterized import parameterized from PIL import Image from requests.exceptions import HTTPError @@ -71,12 +72,54 @@ from diffusers.utils.testing_utils import ( require_compel, require_flax, require_torch_gpu, + run_test_in_subprocess, ) enable_full_determinism() +# Will be run via run_test_in_subprocess +def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): + error = None + try: + # 1. Load models + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + model = torch.compile(model) + scheduler = DDPMScheduler(num_train_timesteps=10) + + ddpm = DDPMPipeline(model, scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) + new_ddpm.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + generator = torch.Generator(device=torch_device).manual_seed(0) + new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + class DownloadTests(unittest.TestCase): def test_one_request_upon_cached(self): # TODO: For some reason this test fails on MPS where no HEAD call is made. @@ -1315,35 +1358,7 @@ class PipelineSlowTests(unittest.TestCase): @require_torch_2 def test_from_save_pretrained_dynamo(self): - # 1. Load models - model = UNet2DModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=3, - out_channels=3, - down_block_types=("DownBlock2D", "AttnDownBlock2D"), - up_block_types=("AttnUpBlock2D", "UpBlock2D"), - ) - model = torch.compile(model) - scheduler = DDPMScheduler(num_train_timesteps=10) - - ddpm = DDPMPipeline(model, scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - - with tempfile.TemporaryDirectory() as tmpdirname: - ddpm.save_pretrained(tmpdirname) - new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) - new_ddpm.to(torch_device) - - generator = torch.Generator(device=torch_device).manual_seed(0) - image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images - - generator = torch.Generator(device=torch_device).manual_seed(0) - new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images - - assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=None) def test_from_pretrained_hub(self): model_path = "google/ddpm-cifar10-32"