1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Prepare a couple more compile tests to run in subprocess.

This commit is contained in:
Pedro Cuenca
2023-05-22 16:04:30 +02:00
parent 1b58d337e3
commit 902af3ddea
2 changed files with 75 additions and 44 deletions

View File

@@ -22,12 +22,37 @@ from typing import Dict, List, Tuple
import numpy as np
import requests_mock
import torch
import traceback
from requests.exceptions import HTTPError
from diffusers.models import UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils import logging, torch_device
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, run_test_in_subprocess
# Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
error = None
try:
init_dict, model_class = in_queue.get(timeout=timeout)
model = model_class(**init_dict)
model.to(torch_device)
model = torch.compile(model)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
assert new_model.__class__ == model_class
except Exception:
error = f"{traceback.format_exc()}"
results = {"error": error}
out_queue.put(results, timeout=timeout)
out_queue.join()
class ModelUtilsTest(unittest.TestCase):
@@ -234,21 +259,12 @@ class ModelTesterMixin:
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@require_torch_gpu
@require_torch_2
def test_from_save_pretrained_dynamo(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model = torch.compile(model)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
assert new_model.__class__ == self.model_class
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
inputs = [init_dict, self.model_class]
run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=inputs)
def test_from_save_pretrained_dtype(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

View File

@@ -28,6 +28,7 @@ import PIL
import requests_mock
import safetensors.torch
import torch
import traceback
from parameterized import parameterized
from PIL import Image
from requests.exceptions import HTTPError
@@ -71,12 +72,54 @@ from diffusers.utils.testing_utils import (
require_compel,
require_flax,
require_torch_gpu,
run_test_in_subprocess,
)
enable_full_determinism()
# Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
error = None
try:
# 1. Load models
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model = torch.compile(model)
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
new_ddpm.to(torch_device)
generator = torch.Generator(device=torch_device).manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
generator = torch.Generator(device=torch_device).manual_seed(0)
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
except Exception:
error = f"{traceback.format_exc()}"
results = {"error": error}
out_queue.put(results, timeout=timeout)
out_queue.join()
class DownloadTests(unittest.TestCase):
def test_one_request_upon_cached(self):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
@@ -1315,35 +1358,7 @@ class PipelineSlowTests(unittest.TestCase):
@require_torch_2
def test_from_save_pretrained_dynamo(self):
# 1. Load models
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model = torch.compile(model)
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
new_ddpm.to(torch_device)
generator = torch.Generator(device=torch_device).manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
generator = torch.Generator(device=torch_device).manual_seed(0)
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=None)
def test_from_pretrained_hub(self):
model_path = "google/ddpm-cifar10-32"