mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Run torch.compile tests in separate subprocesses (#3503)
* Run ControlNet compile test in a separate subprocess
`torch.compile()` spawns several subprocesses and the GPU memory used
was not reclaimed after the test ran. This approach was taken from
`transformers`.
* Style
* Prepare a couple more compile tests to run in subprocess.
* Use require_torch_2 decorator.
* Test inpaint_compile in subprocess.
* Run img2img compile test in subprocess.
* Run stable diffusion compile test in subprocess.
* style
* Temporarily trigger on pr to test.
* Revert "Temporarily trigger on pr to test."
This reverts commit 82d76868dd.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import inspect
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
@@ -477,6 +478,50 @@ def pytest_terminal_summary_main(tr, id):
|
||||
config.option.tbstyle = orig_tbstyle
|
||||
|
||||
|
||||
# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787
|
||||
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
    """
    Run `target_func` in a spawned subprocess and report its outcome through `test_case`.

    In particular, this can avoid (GPU) memory issues.

    The child is started with the "spawn" start method and called as
    `target_func(input_queue, output_queue, timeout)`. It is expected to put a dict with an `"error"` key on the
    output queue: `None` on success, otherwise a message that is used to fail `test_case`.

    Args:
        test_case (`unittest.TestCase`):
            The test that will run `target_func`. It is never sent to the child; only `test_case.fail` is used.
        target_func (`Callable`):
            The function implementing the actual testing logic.
        inputs (`dict`, *optional*, defaults to `None`):
            The inputs that will be passed to `target_func` through an (input) queue.
        timeout (`int`, *optional*, defaults to `None`):
            The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
            variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
    """
    if timeout is None:
        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))

    start_method = "spawn"
    ctx = multiprocessing.get_context(start_method)

    input_queue = ctx.Queue(1)
    output_queue = ctx.JoinableQueue(1)

    # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
    input_queue.put(inputs, timeout=timeout)

    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
    process.start()

    # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
    # the test from exiting properly.
    try:
        results = output_queue.get(timeout=timeout)
        # Acknowledge the item so a child blocked in `output_queue.join()` can finish.
        output_queue.task_done()
    except Exception as e:
        process.terminate()
        # Reap the terminated child before reporting the failure.
        process.join(timeout=timeout)
        # `queue.Empty` has an empty string form, so build an explicit message instead of passing `e` itself.
        test_case.fail(f"Failed to get results from the subprocess (timeout={timeout}s): {e!r}")

    process.join(timeout=timeout)
    if process.is_alive():
        # The child reported a result but did not exit in time; do not leave it running.
        process.terminate()
        process.join(timeout=timeout)

    if results["error"] is not None:
        test_case.fail(f'{results["error"]}')
|
||||
|
||||
|
||||
class CaptureLogger:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import traceback
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from typing import Dict, List, Tuple
|
||||
@@ -27,7 +28,31 @@ from requests.exceptions import HTTPError
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import logging, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
|
||||
from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, run_test_in_subprocess
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
    """
    Subprocess body for the model-level `torch.compile` save/reload check.

    Reads an `(init_dict, model_class)` pair from `in_queue`, builds and compiles the model, saves it to a
    temporary directory, reloads it and asserts the reloaded object has the same class. Whatever happens, a dict
    `{"error": <formatted traceback or None>}` is put on `out_queue`, followed by `out_queue.join()`.
    """
    failure = None
    try:
        init_kwargs, model_cls = in_queue.get(timeout=timeout)

        compiled_model = model_cls(**init_kwargs)
        compiled_model.to(torch_device)
        compiled_model = torch.compile(compiled_model)

        with tempfile.TemporaryDirectory() as save_dir:
            compiled_model.save_pretrained(save_dir)
            reloaded_model = model_cls.from_pretrained(save_dir)
            reloaded_model.to(torch_device)

        assert reloaded_model.__class__ == model_cls
    except Exception:
        # Any failure (including assertion errors) is shipped back as text.
        failure = traceback.format_exc()

    out_queue.put({"error": failure}, timeout=timeout)
    out_queue.join()
|
||||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
@@ -235,20 +260,11 @@ class ModelTesterMixin:
|
||||
max_diff = (image - new_image).abs().sum().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
def test_from_save_pretrained_dynamo(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model = torch.compile(model)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
|
||||
assert new_model.__class__ == self.model_class
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
inputs = [init_dict, self.model_class]
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=inputs)
|
||||
|
||||
def test_from_save_pretrained_dtype(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
@@ -32,7 +32,12 @@ from diffusers import (
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_device
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
@@ -44,6 +49,51 @@ from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMix
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
    """
    Subprocess body for the ControlNet `torch.compile` test.

    Loads the canny ControlNet and the Stable Diffusion v1-5 pipeline, compiles both the UNet and the ControlNet,
    generates one image and compares it against a stored reference array. The item read from `in_queue` is
    discarded. A dict `{"error": <formatted traceback or None>}` is always put on `out_queue`, followed by
    `out_queue.join()`.
    """
    failure = None
    try:
        # The payload is not used; the read only consumes the queued item.
        in_queue.get(timeout=timeout)

        canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=canny_controlnet
        )
        pipe.to("cuda")
        pipe.set_progress_bar_config(disable=None)

        pipe.unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

        pipe.controlnet.to(memory_format=torch.channels_last)
        pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)

        generator = torch.Generator(device="cpu").manual_seed(0)
        control_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        )

        output = pipe("bird", control_image, generator=generator, output_type="np")
        generated = output.images[0]

        assert generated.shape == (768, 512, 3)

        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
        )

        assert np.abs(expected_image - generated).max() < 1.0
    except Exception:
        # Any failure (including assertion errors) is shipped back as text.
        failure = traceback.format_exc()

    out_queue.put({"error": failure}, timeout=timeout)
    out_queue.join()
|
||||
|
||||
|
||||
class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
@@ -594,41 +644,9 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@require_torch_2
|
||||
def test_stable_diffusion_compile(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.0"):
|
||||
print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
|
||||
return
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
|
||||
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.controlnet.to(memory_format=torch.channels_last)
|
||||
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (768, 512, 3)
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
|
||||
)
|
||||
|
||||
assert np.abs(expected_image - image).max() < 1.0
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
|
||||
|
||||
def test_v11_shuffle_global_pool_conditions(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle")
|
||||
|
||||
@@ -15,19 +15,14 @@
|
||||
|
||||
|
||||
import gc
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from packaging import version
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
@@ -44,25 +39,52 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.utils import load_numpy, nightly, slow, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, require_torch_gpu
|
||||
from diffusers.utils.testing_utils import (
|
||||
CaptureLogger,
|
||||
enable_full_determinism,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
)
|
||||
|
||||
from ...models.test_models_unet_2d_condition import create_lora_layers
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def process_fixture():
|
||||
# This will be run before each test
|
||||
command = [sys.executable, os.path.abspath(__file__)]
|
||||
process = subprocess.Popen(command)
|
||||
enable_full_determinism()
|
||||
yield process
|
||||
# This will be run after each test
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
os.kill(process.pid, signal.SIGTERM) # or signal.SIGKILL
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
inputs = in_queue.get(timeout=timeout)
|
||||
torch_device = inputs.pop("torch_device")
|
||||
seed = inputs.pop("seed")
|
||||
inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
|
||||
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
|
||||
sd_pipe.unet.to(memory_format=torch.channels_last)
|
||||
sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
|
||||
assert np.abs(image_slice - expected_slice).max() < 5e-3
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
@@ -927,27 +949,15 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 8e-1
|
||||
|
||||
@require_torch_2
|
||||
def test_stable_diffusion_compile(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.0"):
|
||||
print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
|
||||
return
|
||||
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
|
||||
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
|
||||
sd_pipe.unet.to(memory_format=torch.channels_last)
|
||||
sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
|
||||
assert np.abs(image_slice - expected_slice).max() < 5e-3
|
||||
seed = 0
|
||||
inputs = self.get_inputs(torch_device, seed=seed)
|
||||
# Can't pickle a Generator object
|
||||
del inputs["generator"]
|
||||
inputs["torch_device"] = torch_device
|
||||
inputs["seed"] = seed
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=inputs)
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
|
||||
import gc
|
||||
import random
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
@@ -34,7 +34,13 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
skip_mps,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
@@ -47,6 +53,38 @@ from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMix
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_img2img_compile(in_queue, out_queue, timeout):
    """
    Subprocess body for the img2img `torch.compile` test.

    Reads the pipeline keyword arguments from `in_queue`. The dict also carries `"torch_device"` and `"seed"`
    entries, which are popped and used to build the `"generator"` entry here. The pipeline's UNet is compiled,
    one image is generated, and a 3x3 corner slice is compared with stored reference values. A dict
    `{"error": <formatted traceback or None>}` is always put on `out_queue`, followed by `out_queue.join()`.
    """
    failure = None
    try:
        call_kwargs = in_queue.get(timeout=timeout)
        device = call_kwargs.pop("torch_device")
        seed = call_kwargs.pop("seed")
        call_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

        pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        pipe.unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

        images = pipe(**call_kwargs).images
        corner = images[0, -3:, -3:, -1].flatten()

        assert images.shape == (1, 512, 768, 3)
        expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781])

        assert np.abs(expected_slice - corner).max() < 1e-3
    except Exception:
        # Any failure (including assertion errors) is shipped back as text.
        failure = traceback.format_exc()

    out_queue.put({"error": failure}, timeout=timeout)
    out_queue.join()
|
||||
|
||||
|
||||
class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionImg2ImgPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
@@ -464,27 +502,15 @@ class StableDiffusionImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
|
||||
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
|
||||
|
||||
@require_torch_2
|
||||
def test_img2img_compile(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.0"):
|
||||
print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
|
||||
return
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 768, 3)
|
||||
expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 1e-3
|
||||
seed = 0
|
||||
inputs = self.get_inputs(torch_device, seed=seed)
|
||||
# Can't pickle a Generator object
|
||||
del inputs["generator"]
|
||||
inputs["torch_device"] = torch_device
|
||||
inputs["seed"] = seed
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_img2img_compile, inputs=inputs)
|
||||
|
||||
|
||||
@nightly
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
|
||||
import gc
|
||||
import random
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
@@ -33,7 +33,12 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
)
|
||||
|
||||
from ...models.test_models_unet_2d_condition import create_lora_layers
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
@@ -43,6 +48,40 @@ from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMix
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_inpaint_compile(in_queue, out_queue, timeout):
    """
    Subprocess body for the inpainting `torch.compile` test.

    Reads the pipeline keyword arguments from `in_queue`. The dict also carries `"torch_device"` and `"seed"`
    entries, which are popped and used to build the `"generator"` entry here. The pipeline's UNet is compiled,
    one image is generated, and a 3x3 slice taken at rows/columns 253:256 is compared with stored reference
    values. A dict `{"error": <formatted traceback or None>}` is always put on `out_queue`, followed by
    `out_queue.join()`.
    """
    failure = None
    try:
        call_kwargs = in_queue.get(timeout=timeout)
        device = call_kwargs.pop("torch_device")
        seed = call_kwargs.pop("seed")
        call_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)

        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        pipe.unet.to(memory_format=torch.channels_last)
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

        images = pipe(**call_kwargs).images
        center = images[0, 253:256, 253:256, -1].flatten()

        assert images.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])

        assert np.abs(expected_slice - center).max() < 3e-3
    except Exception:
        # Any failure (including assertion errors) is shipped back as text.
        failure = traceback.format_exc()

    out_queue.put({"error": failure}, timeout=timeout)
    out_queue.join()
|
||||
|
||||
|
||||
class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
|
||||
@@ -315,29 +354,15 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||
# make sure that less than 2.2 GB is allocated
|
||||
assert mem_bytes < 2.2 * 10**9
|
||||
|
||||
@require_torch_2
|
||||
def test_inpaint_compile(self):
|
||||
if version.parse(torch.__version__) < version.parse("2.0"):
|
||||
print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
|
||||
return
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", safety_checker=None
|
||||
)
|
||||
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 3e-3
|
||||
seed = 0
|
||||
inputs = self.get_inputs(torch_device, seed=seed)
|
||||
# Can't pickle a Generator object
|
||||
del inputs["generator"]
|
||||
inputs["torch_device"] = torch_device
|
||||
inputs["seed"] = seed
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_inpaint_compile, inputs=inputs)
|
||||
|
||||
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
|
||||
@@ -20,6 +20,7 @@ import random
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
@@ -73,12 +74,54 @@ from diffusers.utils.testing_utils import (
|
||||
require_compel,
|
||||
require_flax,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
    """
    Subprocess body for the pipeline-level `torch.compile` save/reload check.

    Builds a small `UNet2DModel`, compiles it, wraps it in a `DDPMPipeline`, saves and reloads the pipeline, then
    runs both pipelines from identically seeded generators and asserts their outputs match. `in_queue` is accepted
    for the common call signature but is not read here. A dict `{"error": <formatted traceback or None>}` is
    always put on `out_queue`, followed by `out_queue.join()`.
    """
    failure = None
    try:
        # 1. Load models
        unet = torch.compile(
            UNet2DModel(
                block_out_channels=(32, 64),
                layers_per_block=2,
                sample_size=32,
                in_channels=3,
                out_channels=3,
                down_block_types=("DownBlock2D", "AttnDownBlock2D"),
                up_block_types=("AttnUpBlock2D", "UpBlock2D"),
            )
        )
        scheduler = DDPMScheduler(num_train_timesteps=10)

        ddpm = DDPMPipeline(unet, scheduler)
        ddpm.to(torch_device)
        ddpm.set_progress_bar_config(disable=None)

        with tempfile.TemporaryDirectory() as save_dir:
            ddpm.save_pretrained(save_dir)
            reloaded_ddpm = DDPMPipeline.from_pretrained(save_dir)
            reloaded_ddpm.to(torch_device)

        # Both pipelines start from a generator seeded with the same value.
        generator = torch.Generator(device=torch_device).manual_seed(0)
        image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

        generator = torch.Generator(device=torch_device).manual_seed(0)
        reloaded_image = reloaded_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

        assert np.abs(image - reloaded_image).sum() < 1e-5, "Models don't give the same forward pass"
    except Exception:
        # Any failure (including assertion errors) is shipped back as text.
        failure = traceback.format_exc()

    out_queue.put({"error": failure}, timeout=timeout)
    out_queue.join()
|
||||
|
||||
|
||||
class CustomEncoder(ModelMixin, ConfigMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -1342,35 +1385,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
|
||||
@require_torch_2
|
||||
def test_from_save_pretrained_dynamo(self):
|
||||
# 1. Load models
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
model = torch.compile(model)
|
||||
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||
|
||||
ddpm = DDPMPipeline(model, scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
ddpm.save_pretrained(tmpdirname)
|
||||
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
|
||||
new_ddpm.to(torch_device)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=None)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
|
||||
Reference in New Issue
Block a user