Mirror of https://github.com/huggingface/diffusers.git
Test inpaint_compile in subprocess.
@@ -19,7 +19,7 @@ import unittest
 
 import numpy as np
 import torch
-from packaging import version
+import traceback
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -33,7 +33,7 @@ from diffusers import (
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_2, require_torch_gpu, run_test_in_subprocess
 
 from ...models.test_models_unet_2d_condition import create_lora_layers
 from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
@@ -42,6 +42,40 @@ from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMix
 
 enable_full_determinism()
 
 
+# Will be run via run_test_in_subprocess
+def _test_inpaint_compile(in_queue, out_queue, timeout):
+    error = None
+    try:
+        inputs = in_queue.get(timeout=timeout)
+        torch_device = inputs.pop("torch_device")
+        seed = inputs.pop("seed")
+        inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
+
+        pipe = StableDiffusionInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting", safety_checker=None
+        )
+        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.unet.to(memory_format=torch.channels_last)
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+        image = pipe(**inputs).images
+        image_slice = image[0, 253:256, 253:256, -1].flatten()
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])
+
+        assert np.abs(expected_slice - image_slice).max() < 3e-3
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
 class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionInpaintPipeline
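
Note: the worker above spawns nothing itself. It blocks on in_queue for its inputs and reports back through a joinable out_queue, so the parent process is expected to create both queues, start the child, and acknowledge the result. That parent side is run_test_in_subprocess from diffusers.utils.testing_utils (imported in the previous hunk); its implementation is not part of this diff, so the following is only a minimal sketch of a compatible runner, assuming torch.multiprocessing with a spawn context. The function name _run_compile_test_in_subprocess, the default timeout, and the queue sizes are placeholders, not the library's API.

# Hypothetical parent-side runner compatible with _test_inpaint_compile;
# the real helper is run_test_in_subprocess in diffusers.utils.testing_utils
# and may differ in its details.
import torch.multiprocessing as mp


def _run_compile_test_in_subprocess(test_case, target_func, inputs=None, timeout=600):
    # "spawn" gives the child a fresh process (and a fresh CUDA context),
    # which is the point of isolating torch.compile tests.
    ctx = mp.get_context("spawn")
    in_queue = ctx.Queue(1)
    out_queue = ctx.JoinableQueue(1)

    # Everything sent through the queue is pickled, hence the plain seed
    # instead of a torch.Generator in the test changed below.
    in_queue.put(inputs, timeout=timeout)

    process = ctx.Process(target=target_func, args=(in_queue, out_queue, timeout))
    process.start()

    # The worker ends with out_queue.put(...) and out_queue.join(), so the
    # parent must consume the result and acknowledge it with task_done().
    results = out_queue.get(timeout=timeout)
    out_queue.task_done()
    process.join(timeout=timeout)

    if results["error"] is not None:
        test_case.fail(f"{results['error']}")
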
@@ -315,29 +349,15 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
         # make sure that less than 2.2 GB is allocated
         assert mem_bytes < 2.2 * 10**9
 
+    @require_torch_2
     def test_inpaint_compile(self):
-        if version.parse(torch.__version__) < version.parse("2.0"):
-            print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
-            return
-
-        pipe = StableDiffusionInpaintPipeline.from_pretrained(
-            "runwayml/stable-diffusion-inpainting", safety_checker=None
-        )
-        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        pipe.unet.to(memory_format=torch.channels_last)
-        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
-        inputs = self.get_inputs(torch_device)
-        image = pipe(**inputs).images
-        image_slice = image[0, 253:256, 253:256, -1].flatten()
-
-        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])
-
-        assert np.abs(expected_slice - image_slice).max() < 3e-3
+        seed = 0
+        inputs = self.get_inputs(torch_device, seed=seed)
+        # Can't pickle a Generator object
+        del inputs["generator"]
+        inputs["torch_device"] = torch_device
+        inputs["seed"] = seed
+        run_test_in_subprocess(test_case=self, target_func=_test_inpaint_compile, inputs=inputs)
 
     def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
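
Note on the "Can't pickle a Generator object" comment: the inputs dict crosses a process boundary via the queue, and multiprocessing pickles whatever is put on it. A torch.Generator wraps C-level RNG state and cannot be pickled, so the test drops it and ships the integer seed instead; the worker then rebuilds an equivalent generator on its own device. A small illustration of the failure mode follows (the exact exception message can vary across PyTorch versions).

import pickle

import torch

gen = torch.Generator().manual_seed(0)
try:
    # multiprocessing queues do this implicitly; it fails for Generator objects
    pickle.dumps(gen)
except TypeError as err:
    print(f"cannot send a Generator across processes: {err}")

# What the test does instead: send the plain seed and rebuild the generator
# on the receiving side, which reproduces the same random stream.
rebuilt = torch.Generator().manual_seed(0)
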