Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Run ControlNet compile test in a separate subprocess
`torch.compile()` spawns several subprocesses, and the GPU memory they used was not reclaimed after the test ran. This approach was taken from `transformers`.
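For context, a minimal sketch of why the subprocess approach reclaims memory (the `_allocate_on_gpu` helper and tensor size below are illustrative, not part of this commit): CUDA memory lives in the process that allocated it, so allocations made in a `spawn`ed child are released when the child exits, while allocations made in the pytest process itself can linger for the rest of the run.

    # Illustrative sketch only (not part of this commit); requires a CUDA GPU.
    import multiprocessing

    import torch


    def _allocate_on_gpu(queue):
        # Runs in the child: the CUDA context, and everything it allocated,
        # disappears when this process terminates.
        x = torch.randn(1024, 1024, device="cuda")
        queue.put(x.element_size() * x.nelement())


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        queue = ctx.Queue(1)
        process = ctx.Process(target=_allocate_on_gpu, args=(queue,))
        process.start()
        print(f"child allocated {queue.get()} bytes")
        process.join()
        # The parent never initialized CUDA, so it holds no GPU memory here.
        print(f"parent holds {torch.cuda.memory_allocated()} bytes")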
src/diffusers/utils/testing_utils.py
@@ -1,5 +1,6 @@
 import inspect
 import logging
+import multiprocessing
 import os
 import random
 import re
@@ -477,6 +478,50 @@ def pytest_terminal_summary_main(tr, id):
     config.option.tbstyle = orig_tbstyle


+# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
+    """
+    Run a test in a subprocess. In particular, this can avoid (GPU) memory issues.
+
+    Args:
+        test_case (`unittest.TestCase`):
+            The test that will run `target_func`.
+        target_func (`Callable`):
+            The function implementing the actual testing logic.
+        inputs (`dict`, *optional*, defaults to `None`):
+            The inputs that will be passed to `target_func` through an (input) queue.
+        timeout (`int`, *optional*, defaults to `None`):
+            The timeout (in seconds) that will be passed to the input and output queues. If not specified, the
+            environment variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
+    """
+    if timeout is None:
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+
+    start_method = "spawn"
+    ctx = multiprocessing.get_context(start_method)
+
+    input_queue = ctx.Queue(1)
+    output_queue = ctx.JoinableQueue(1)
+
+    # We can't send `unittest.TestCase` to the child process; otherwise we get issues regarding pickle.
+    input_queue.put(inputs, timeout=timeout)
+
+    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+    process.start()
+    # Kill the child process if we can't get outputs from it in time; otherwise, the hanging subprocess prevents
+    # the test from exiting properly.
+    try:
+        results = output_queue.get(timeout=timeout)
+        output_queue.task_done()
+    except Exception as e:
+        process.terminate()
+        test_case.fail(e)
+    process.join(timeout=timeout)
+
+    if results["error"] is not None:
+        test_case.fail(f'{results["error"]}')
+
+
 class CaptureLogger:
     """
     Args:
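For reference, a minimal usage sketch of this helper (the `_test_dummy` function and `DummySubprocessTests` class are hypothetical, not part of this commit). The contract for a target function: read its inputs from `in_queue`, catch its own exceptions, always put a dict with an `"error"` key on `out_queue`, and finish with `out_queue.join()` so the parent's `task_done()` call releases it.

    # Hypothetical usage of run_test_in_subprocess (not part of this commit).
    import traceback
    import unittest

    from diffusers.utils.testing_utils import run_test_in_subprocess


    # Must be a top-level function so the "spawn" context can pickle it.
    def _test_dummy(in_queue, out_queue, timeout):
        error = None
        try:
            inputs = in_queue.get(timeout=timeout)  # the `inputs` dict, or None
            assert inputs is None
        except Exception:
            error = f"{traceback.format_exc()}"
        # Always report back, even on failure, so the parent doesn't hang on get().
        out_queue.put({"error": error}, timeout=timeout)
        out_queue.join()


    class DummySubprocessTests(unittest.TestCase):
        def test_dummy(self):
            run_test_in_subprocess(test_case=self, target_func=_test_dummy, inputs=None)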
@@ -19,6 +19,7 @@ import unittest

 import numpy as np
 import torch
+import traceback
 from packaging import version
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -32,7 +33,7 @@ from diffusers import (
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
 from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_device
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import require_torch_gpu, run_test_in_subprocess

 from ..pipeline_params import (
     TEXT_TO_IMAGE_BATCH_PARAMS,
@@ -45,6 +46,51 @@ torch.backends.cuda.matmul.allow_tf32 = False
 torch.use_deterministic_algorithms(True)


+# Will be run via run_test_in_subprocess
+def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
+        )
+        pipe.to("cuda")
+        pipe.set_progress_bar_config(disable=None)
+
+        pipe.unet.to(memory_format=torch.channels_last)
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+        pipe.controlnet.to(memory_format=torch.channels_last)
+        pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "bird"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+
+        output = pipe(prompt, image, generator=generator, output_type="np")
+        image = output.images[0]
+
+        assert image.shape == (768, 512, 3)
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
+        )
+
+        assert np.abs(expected_image - image).max() < 1.0
+
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
 class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
@@ -595,41 +641,13 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_stable_diffusion_compile(self):
         if version.parse(torch.__version__) < version.parse("2.0"):
-            print(f"Test `test_stable_diffusion_ddim` is skipped because {torch.__version__} is < 2.0")
+            print(f"Test `test_stable_diffusion_compile` is skipped because {torch.__version__} is < 2.0")
             return
+        run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
-
-        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
-
-        pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
-        )
-        pipe.to("cuda")
-        pipe.set_progress_bar_config(disable=None)
-
-        pipe.unet.to(memory_format=torch.channels_last)
-        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
-        pipe.controlnet.to(memory_format=torch.channels_last)
-        pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
-
-        generator = torch.Generator(device="cpu").manual_seed(0)
-        prompt = "bird"
-        image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
-        )
-
-        output = pipe(prompt, image, generator=generator, output_type="np")
-        image = output.images[0]
-
-        assert image.shape == (768, 512, 3)
-
-        expected_image = load_numpy(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
-        )
-
-        assert np.abs(expected_image - image).max() < 1.0

     def test_v11_shuffle_global_pool_conditions(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle")