
[Tests] Speed up some fast pipeline tests (#7477)

* speed up test_vae_slicing in animatediff

* speed up test_karras_schedulers_shape for attend and excite.

* style.

* get the static slices out.

* specify torch print options.

* modify

* test run with controlnet

* specify kwarg

* fix: things

* not None

* flatten

* controlnet img2img

* complete controlnet sd

* finish more

* finish more

* finish more

* finish more

* finish the final batch

* add cpu check for expected_pipe_slice.

* finish the rest

* remove print

* style

* fix ssd1b controlnet test

* checking ssd1b

* disable the test.

* make the test_ip_adapter_single controlnet test more robust

* fix: simple inpaint

* multi

* disable panorama

* enable again

* panorama is shaky so leave it for now

* remove print

* raise tolerance.
Sayak Paul
2024-03-29 14:11:38 +05:30
committed by GitHub
parent 34c90dbb31
commit fac761694a
19 changed files with 263 additions and 15 deletions


@@ -105,10 +105,21 @@ def numpy_cosine_similarity_distance(a, b):
     return distance


-def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_name="expected_slice"):
+def print_tensor_test(
+    tensor,
+    limit_to_slices=None,
+    max_torch_print=None,
+    filename="test_corrections.txt",
+    expected_tensor_name="expected_slice",
+):
+    if max_torch_print:
+        torch.set_printoptions(threshold=10_000)
+
     test_name = os.environ.get("PYTEST_CURRENT_TEST")
     if not torch.is_tensor(tensor):
         tensor = torch.from_numpy(tensor)
+    if limit_to_slices:
+        tensor = tensor[0, -3:, -3:, -1]

     tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
     # format is usually:
@@ -117,7 +128,7 @@ def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_n
     test_file, test_class, test_fn = test_name.split("::")
     test_fn = test_fn.split()[0]
     with open(filename, "a") as f:
-        print(";".join([test_file, test_class, test_fn, output_str]), file=f)
+        print("::".join([test_file, test_class, test_fn, output_str]), file=f)


def get_tests_dir(append_path=None):
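
For reference, a minimal sketch of how the extended helper might be invoked to record a slice (the tensor and test function here are hypothetical; the helper reads PYTEST_CURRENT_TEST, so it only works inside a running pytest test):

    import torch
    from diffusers.utils.testing_utils import print_tensor_test

    def test_regenerate_slice():
        # Stand-in for a pipeline output of shape (batch, height, width, channels).
        image = torch.rand(1, 64, 64, 3)
        # limit_to_slices=True reduces the tensor to image[0, -3:, -3:, -1] before
        # printing; max_torch_print raises torch's print threshold so long tensors
        # are written out in full instead of being elided.
        print_tensor_test(image, limit_to_slices=True, max_torch_print=True)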


@@ -131,6 +131,42 @@ class AnimateDiffPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         pass

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array(
+                [
+                    0.5541,
+                    0.5802,
+                    0.5074,
+                    0.4583,
+                    0.4729,
+                    0.5374,
+                    0.4051,
+                    0.4495,
+                    0.4480,
+                    0.5292,
+                    0.6322,
+                    0.6265,
+                    0.5455,
+                    0.4771,
+                    0.5795,
+                    0.5845,
+                    0.4172,
+                    0.6066,
+                    0.6535,
+                    0.4113,
+                    0.6833,
+                    0.5736,
+                    0.3589,
+                    0.5730,
+                    0.4205,
+                    0.3786,
+                    0.5323,
+                ]
+            )
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_inference_batch_single_identical(
         self,
         batch_size=2,
@@ -299,6 +335,9 @@ class AnimateDiffPipelineFastTests(
         max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
         self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")

+    def test_vae_slicing(self):
+        return super().test_vae_slicing(image_count=2)
+

@slow
@require_torch_gpu


@@ -135,6 +135,34 @@ class AnimateDiffVideoToVideoPipelineFastTests(IPAdapterTesterMixin, PipelineTes
     def test_attention_slicing_forward_pass(self):
         pass

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array(
+                [
+                    0.4947,
+                    0.4780,
+                    0.4340,
+                    0.4666,
+                    0.4028,
+                    0.4645,
+                    0.4915,
+                    0.4101,
+                    0.4308,
+                    0.4581,
+                    0.3582,
+                    0.4953,
+                    0.4466,
+                    0.5348,
+                    0.5863,
+                    0.5299,
+                    0.5213,
+                    0.5017,
+                ]
+            )
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_inference_batch_single_identical(
         self,
         batch_size=2,


@@ -221,6 +221,12 @@ class ControlNetPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.5234, 0.3333, 0.1745, 0.7605, 0.6224, 0.4637, 0.6989, 0.7526, 0.4665])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -455,6 +461,12 @@ class StableDiffusionMultiControlNetPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.2422, 0.3425, 0.4048, 0.5351, 0.3503, 0.2419, 0.4645, 0.4570, 0.3804])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_save_pretrained_raise_not_implemented_exception(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -668,6 +680,12 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.5264, 0.3203, 0.1602, 0.8235, 0.6332, 0.4593, 0.7226, 0.7777, 0.4780])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_save_pretrained_raise_not_implemented_exception(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)


@@ -174,6 +174,12 @@ class ControlNetImg2ImgPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.7096, 0.5149, 0.3571, 0.5897, 0.4715, 0.4052, 0.6098, 0.6886, 0.4213])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -366,6 +372,12 @@ class StableDiffusionMultiControlNetPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.5293, 0.7339, 0.6642, 0.3950, 0.5212, 0.5175, 0.7002, 0.5907, 0.5182])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_save_pretrained_raise_not_implemented_exception(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)


@@ -191,6 +191,15 @@ class StableDiffusionXLControlNetPipelineFastTests(
     def test_attention_slicing_forward_pass(self):
         return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)

+    def test_ip_adapter_single(self, from_ssd1b=False, expected_pipe_slice=None):
+        if not from_ssd1b:
+            expected_pipe_slice = None
+            if torch_device == "cpu":
+                expected_pipe_slice = np.array(
+                    [0.7331, 0.5907, 0.5667, 0.6029, 0.5679, 0.5968, 0.4033, 0.4761, 0.5090]
+                )
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -1042,6 +1051,12 @@ class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNe
         # make sure that it's equal
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.6832, 0.5703, 0.5460, 0.6300, 0.5856, 0.6034, 0.4494, 0.4613, 0.5036])
+        return super().test_ip_adapter_single(from_ssd1b=True, expected_pipe_slice=expected_pipe_slice)
+
     def test_controlnet_sdxl_lcm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator


@@ -170,6 +170,12 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
         return inputs

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.6265, 0.5441, 0.5384, 0.5446, 0.5810, 0.5908, 0.5414, 0.5428, 0.5353])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_stable_diffusion_xl_controlnet_img2img(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()


@@ -108,6 +108,12 @@ class LatentConsistencyModelPipelineFastTests(
         }
         return inputs

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.1403, 0.5072, 0.5316, 0.1202, 0.3865, 0.4211, 0.5363, 0.3557, 0.3645])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_lcm_onestep(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator


@@ -119,6 +119,12 @@ class LatentConsistencyModelImg2ImgPipelineFastTests(
         }
         return inputs

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.4003, 0.3718, 0.2863, 0.5500, 0.5587, 0.3772, 0.4617, 0.4961, 0.4417])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_lcm_onestep(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator


@@ -138,6 +138,43 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.T
         assert isinstance(pipe.unet, UNetMotionModel)

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array(
+                [
+                    0.5609,
+                    0.5756,
+                    0.4830,
+                    0.4420,
+                    0.4547,
+                    0.5129,
+                    0.3779,
+                    0.4042,
+                    0.3772,
+                    0.4450,
+                    0.5710,
+                    0.5536,
+                    0.4835,
+                    0.4308,
+                    0.5578,
+                    0.5578,
+                    0.4395,
+                    0.5440,
+                    0.6051,
+                    0.4651,
+                    0.6258,
+                    0.5662,
+                    0.3988,
+                    0.5108,
+                    0.4153,
+                    0.3993,
+                    0.4803,
+                ]
+            )
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     @unittest.skip("Attention slicing is not enabled in this pipeline")
     def test_attention_slicing_forward_pass(self):
         pass


@@ -370,6 +370,12 @@ class StableDiffusionPipelineFastTests(
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_stable_diffusion_ddim_factor_8(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator


@@ -253,6 +253,12 @@ class StableDiffusionImg2ImgPipelineFastTests(
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.4932, 0.5092, 0.5135, 0.5517, 0.5626, 0.6621, 0.6490, 0.5021, 0.5441])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_stable_diffusion_img2img_multiple_init_images(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()


@@ -388,6 +388,15 @@ class StableDiffusionInpaintPipelineFastTests(
         # they should be the same
         assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)

+    def test_ip_adapter_single(self, from_simple=False, expected_pipe_slice=None):
+        if not from_simple:
+            expected_pipe_slice = None
+            if torch_device == "cpu":
+                expected_pipe_slice = np.array(
+                    [0.4390, 0.5452, 0.3772, 0.5448, 0.6031, 0.4480, 0.5194, 0.4687, 0.4640]
+                )
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+

class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
    pipeline_class = StableDiffusionInpaintPipeline
@@ -475,6 +484,12 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
         }
         return inputs

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.6345, 0.5395, 0.5611, 0.5403, 0.5830, 0.5855, 0.5193, 0.5443, 0.5211])
+        return super().test_ip_adapter_single(from_simple=True, expected_pipe_slice=expected_pipe_slice)
+
     def test_stable_diffusion_inpaint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()


@@ -185,6 +185,9 @@ class StableDiffusionAttendAndExcitePipelineFastTests(
     def test_save_load_optional_components(self):
         super().test_save_load_optional_components(expected_max_difference=4e-4)

+    def test_karras_schedulers_shape(self):
+        super().test_karras_schedulers_shape(num_inference_steps_for_strength_for_iterations=3)
+

@require_torch_gpu
@nightly


@@ -292,6 +292,12 @@ class StableDiffusionXLPipelineFastTests(
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.5552, 0.5569, 0.4725, 0.4348, 0.4994, 0.4632, 0.5142, 0.5012, 0.4700])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_attention_slicing_forward_pass(self):
         super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)


@@ -294,6 +294,15 @@ class StableDiffusionXLAdapterPipelineFastTests(
         }
         return inputs

+    def test_ip_adapter_single(self, from_multi=False, expected_pipe_slice=None):
+        if not from_multi:
+            expected_pipe_slice = None
+            if torch_device == "cpu":
+                expected_pipe_slice = np.array(
+                    [0.5753, 0.6022, 0.4728, 0.4986, 0.5708, 0.4645, 0.5194, 0.5134, 0.4730]
+                )
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_stable_diffusion_adapter_default_case(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -446,6 +455,12 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
         )
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.5813, 0.6100, 0.4756, 0.5057, 0.5720, 0.4632, 0.5177, 0.5125, 0.4718])
+        return super().test_ip_adapter_single(from_multi=True, expected_pipe_slice=expected_pipe_slice)
+
     def test_inference_batch_consistent(
         self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
     ):


@@ -311,6 +311,12 @@ class StableDiffusionXLImg2ImgPipelineFastTests(
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.5174, 0.4512, 0.5006, 0.6273, 0.5160, 0.6825, 0.6655, 0.5840, 0.5675])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_stable_diffusion_xl_img2img_tiny_autoencoder(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()


@@ -223,6 +223,12 @@ class StableDiffusionXLInpaintPipelineFastTests(
         }
         return inputs

+    def test_ip_adapter_single(self):
+        expected_pipe_slice = None
+        if torch_device == "cpu":
+            expected_pipe_slice = np.array([0.7971, 0.5371, 0.5973, 0.5642, 0.6689, 0.6894, 0.5770, 0.6063, 0.5261])
+        return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
+
     def test_components_function(self):
         init_components = self.get_dummy_components()
         init_components.pop("requires_aesthetics_score")


@@ -37,11 +37,7 @@ from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
-from diffusers.utils.testing_utils import (
-    CaptureLogger,
-    require_torch,
-    torch_device,
-)
+from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device

 from ..models.autoencoders.test_models_vae import (
     get_asym_autoencoder_kl_config,
@@ -71,7 +67,7 @@ class SDFunctionTesterMixin:
     It provides a set of common tests for PyTorch pipeline that inherit from StableDiffusionMixin, e.g. vae_slicing, vae_tiling, freeu, etc.
     """

-    def test_vae_slicing(self):
+    def test_vae_slicing(self, image_count=4):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         # components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)
@@ -79,8 +75,6 @@ class SDFunctionTesterMixin:
         pipe = pipe.to(device)
         pipe.set_progress_bar_config(disable=None)

-        image_count = 4
-
         inputs = self.get_dummy_inputs(device)
         inputs["prompt"] = [inputs["prompt"]] * image_count
         if "image" in inputs:  # fix batch size mismatch in I2V_Gen pipeline
@@ -241,7 +235,11 @@ class IPAdapterTesterMixin:
         inputs["return_dict"] = False
         return inputs

-    def test_ip_adapter_single(self, expected_max_diff: float = 1e-4):
+    def test_ip_adapter_single(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+        # Raising the tolerance for this test when it's run on a CPU because we
+        # compare against static slices and that can be shaky (with a VVVV low probability).
+        expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
+
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         pipe.set_progress_bar_config(disable=None)
@@ -249,7 +247,10 @@ class IPAdapterTesterMixin:

         # forward pass without ip adapter
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
-        output_without_adapter = pipe(**inputs)[0]
+        if expected_pipe_slice is None:
+            output_without_adapter = pipe(**inputs)[0]
+        else:
+            output_without_adapter = expected_pipe_slice

         adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
         pipe.unet._load_ip_adapter_weights(adapter_state_dict)
@@ -259,12 +260,16 @@ class IPAdapterTesterMixin:
         inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
         pipe.set_ip_adapter_scale(0.0)
         output_without_adapter_scale = pipe(**inputs)[0]
+        if expected_pipe_slice is not None:
+            output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()

         # forward pass with single ip adapter, but with scale of adapter weights
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
         inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
         pipe.set_ip_adapter_scale(42.0)
         output_with_adapter_scale = pipe(**inputs)[0]
+        if expected_pipe_slice is not None:
+            output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()

         max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
         max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
@@ -514,7 +519,9 @@ class PipelineKarrasSchedulerTesterMixin:
     equivalence of dict and tuple outputs, etc.
     """

-    def test_karras_schedulers_shape(self):
+    def test_karras_schedulers_shape(
+        self, num_inference_steps_for_strength=4, num_inference_steps_for_strength_for_iterations=5
+    ):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -527,13 +534,13 @@ class PipelineKarrasSchedulerTesterMixin:
         inputs["num_inference_steps"] = 2

         if "strength" in inputs:
-            inputs["num_inference_steps"] = 4
+            inputs["num_inference_steps"] = num_inference_steps_for_strength
             inputs["strength"] = 0.5

         outputs = []
         for scheduler_enum in KarrasDiffusionSchedulers:
             if "KDPM2" in scheduler_enum.name:
-                inputs["num_inference_steps"] = 5
+                inputs["num_inference_steps"] = num_inference_steps_for_strength_for_iterations

             scheduler_cls = getattr(diffusers, scheduler_enum.name)
             pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
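
Taken together, the changes follow one pattern: on CPU, each pipeline test hands a static, previously recorded 3x3 slice to test_ip_adapter_single, so the mixin skips the baseline forward pass and compares against the cached values at a slightly raised 9e-4 tolerance. A self-contained sketch of that caching idea, with hypothetical names and a stubbed pipeline standing in for a real diffusers one:

    import numpy as np

    def run_pipeline_stub(seed: int = 0) -> np.ndarray:
        # Stand-in for a deterministic CPU forward pass returning (B, H, W, C) images.
        rng = np.random.default_rng(seed)
        return rng.random((1, 8, 8, 3))

    # Recorded once (e.g. via print_tensor_test(..., limit_to_slices=True)) and
    # committed as a constant, replacing a second forward pass at test time.
    EXPECTED_PIPE_SLICE = run_pipeline_stub()[0, -3:, -3:, -1].flatten()

    def test_zero_scale_matches_cached_baseline():
        output = run_pipeline_stub()  # the only forward pass the test still runs
        pipe_slice = output[0, -3:, -3:, -1].flatten()
        assert np.abs(pipe_slice - EXPECTED_PIPE_SLICE).max() < 9e-4  # CPU tolerance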