1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
yiyixuxu
2025-07-16 04:26:27 +02:00
parent 0a5c90ed47
commit d92855ddf0
3 changed files with 744 additions and 64 deletions

View File

@@ -478,23 +478,22 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
combined_dict[output_param.name] = output_param
return list(combined_dict.values())
@property
def input_names(self) -> List[str]:
return [input_param.name for input_param in self.inputs]
@property
def intermediate_input_names(self) -> List[str]:
return [input_param.name for input_param in self.intermediate_inputs]
@property
def intermediate_output_names(self) -> List[str]:
return [output_param.name for output_param in self.intermediate_outputs]
@property
def output_names(self) -> List[str]:
return [output_param.name for output_param in self.outputs]
class PipelineBlock(ModularPipelineBlocks):
@@ -2843,8 +2842,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
**spec_dict,
)
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)
sub_block.set_progress_bar_config(**kwargs)

View File

@@ -23,7 +23,8 @@ from diffusers import (
ComponentsManager,
LCMScheduler,
ModularPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
@@ -38,11 +39,9 @@ from ..pipeline_params import (
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
from ..test_modular_pipelines_common import (
ModularIPAdapterTesterMixin,
ModularPipelineTesterMixin,
)
@@ -50,18 +49,25 @@ enable_full_determinism()
class StableDiffusionXLModularPipelineFastTests(
SDFunctionTesterMixin,
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
ModularIPAdapterTesterMixin,
ModularPipelineTesterMixin,
unittest.TestCase,
):
pipeline_class = StableDiffusionXLPipeline
params = (TEXT_TO_IMAGE_PARAMS | IMAGE_INPAINTING_PARAMS) - {"guidance_scale"}
pipeline_class = StableDiffusionXLModularPipeline
pipeline_blocks_class = StableDiffusionXLAutoBlocks
repo = "hf-internal-testing/tiny-sdxl-modular"
params = (TEXT_TO_IMAGE_PARAMS | IMAGE_INPAINTING_PARAMS) - {
"guidance_scale",
"prompt_embeds",
"negative_prompt_embeds",
}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS | IMAGE_INPAINTING_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
test_layerwise_casting = False
test_group_offloading = False
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
pipeline.load_default_components(torch_dtype=torch_dtype)
return pipeline
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
@@ -78,7 +84,7 @@ class StableDiffusionXLModularPipelineFastTests(
def test_stable_diffusion_xl_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe")
sd_pipe = self.get_pipeline()
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
@@ -87,13 +93,17 @@ class StableDiffusionXLModularPipelineFastTests(
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5388, 0.5452, 0.4694, 0.4583, 0.5253, 0.4832, 0.5288, 0.5035, 0.47])
expected_slice = np.array(
[0.5966781, 0.62939394, 0.48465094, 0.51573336, 0.57593524, 0.47035995, 0.53410417, 0.51436996, 0.47313565]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
f"image_slice: {image_slice.flatten()}, expected_slice: {expected_slice.flatten()}"
)
def test_stable_diffusion_xl_euler_lcm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe")
sd_pipe = self.get_pipeline()
sd_pipe.update_components(scheduler=LCMScheduler.from_config(sd_pipe.scheduler.config))
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
@@ -103,41 +113,23 @@ class StableDiffusionXLModularPipelineFastTests(
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156])
expected_slice = np.array(
[0.6880376, 0.6511651, 0.587455, 0.61763, 0.55432945, 0.52064973, 0.5783733, 0.54915607, 0.5460011]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_xl_euler_lcm_custom_timesteps(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe")
sd_pipe.update_components(scheduler=LCMScheduler.from_config(sd_pipe.scheduler.config))
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["num_inference_steps"]
inputs["timesteps"] = [999, 499]
image = sd_pipe(**inputs, output="images")
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
f"image_slice: {image_slice.flatten()}, expected_slice: {expected_slice.flatten()}"
)
@require_torch_accelerator
def test_stable_diffusion_xl_offloads(self):
pipes = []
sd_pipe = ModularPipeline.from_pretrained(
"hf-internal-testing/tiny-sd-pipe",
).to(torch_device)
sd_pipe = self.get_pipeline().to(torch_device)
pipes.append(sd_pipe)
cm = ComponentsManager()
cm.enable_auto_cpu_offload(device=torch_device)
sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe", components_manager=cm).to(
torch_device
)
sd_pipe = self.get_pipeline(components_manager=cm)
pipes.append(sd_pipe)
image_slices = []
@@ -148,21 +140,20 @@ class StableDiffusionXLModularPipelineFastTests(
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
def test_stable_diffusion_xl_multi_prompts(self):
sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
sd_pipe = self.get_pipeline().to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
output = sd_pipe(**inputs, output="images")
image_slice_1 = output.images[0, -3:, -3:, -1]
image_slice_1 = output[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs, output="images")
image_slice_2 = output.images[0, -3:, -3:, -1]
image_slice_2 = output[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
@@ -171,7 +162,7 @@ class StableDiffusionXLModularPipelineFastTests(
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs, output="images")
image_slice_3 = output.images[0, -3:, -3:, -1]
image_slice_3 = output[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
@@ -180,14 +171,14 @@ class StableDiffusionXLModularPipelineFastTests(
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs, output="images")
image_slice_1 = output.images[0, -3:, -3:, -1]
image_slice_1 = output[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs, output="images")
image_slice_2 = output.images[0, -3:, -3:, -1]
image_slice_2 = output[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
@@ -197,15 +188,14 @@ class StableDiffusionXLModularPipelineFastTests(
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs, output="images")
image_slice_3 = output.images[0, -3:, -3:, -1]
image_slice_3 = output[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
def test_stable_diffusion_xl_negative_conditions(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
sd_pipe = sd_pipe.to(device)
sd_pipe = self.get_pipeline().to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
@@ -225,21 +215,24 @@ class StableDiffusionXLModularPipelineFastTests(
def test_stable_diffusion_xl_save_from_pretrained(self):
pipes = []
sd_pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
sd_pipe = self.get_pipeline().to(torch_device)
pipes.append(sd_pipe)
with tempfile.TemporaryDirectory() as tmpdirname:
sd_pipe.save_pretrained(tmpdirname)
sd_pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
sd_pipe.load_default_components(torch_dtype=torch.float32)
sd_pipe.to(torch_device)
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
pipe.unet.set_default_attn_processor()
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)

View File

@@ -0,0 +1,689 @@
import gc
import unittest
from typing import Any, Callable, Dict, Union
import numpy as np
import torch
import diffusers
from diffusers import (
ClassifierFreeGuidance,
DiffusionPipeline,
)
from diffusers.loaders import ModularIPAdapterMixin
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch,
torch_device,
)
from ..models.unets.test_models_unet_2d_condition import (
create_ip_adapter_faceid_state_dict,
create_ip_adapter_state_dict,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
def check_same_shape(tensor_list):
shapes = [tensor.shape for tensor in tensor_list]
return all(shape == shapes[0] for shape in shapes[1:])
class ModularIPAdapterTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
It provides a set of common tests for pipelines that support IP Adapters.
"""
def test_pipeline_inputs_and_blocks(self):
blocks = self.pipeline_blocks_class()
parameters = blocks.input_names
assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
self.assertIn(
"ip_adapter_image",
parameters,
"`ip_adapter_image` argument must be supported by the `__call__` method",
)
self.assertIn(
"ip_adapter",
blocks.sub_blocks,
"pipeline must contain an IPAdapter block",
)
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
intermediate_parameters = blocks.intermediate_input_names
self.assertNotIn(
"ip_adapter_image",
parameters,
"`ip_adapter_image` argument must be removed from the `__call__` method",
)
self.assertNotIn(
"ip_adapter_image_embeds",
intermediate_parameters,
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
)
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_masks(self, input_size: int = 64):
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
_masks[0, :, :, : int(input_size / 2)] = 1
return _masks
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names
if "image" in parameters and "strength" in parameters:
inputs["num_inference_steps"] = 4
inputs["output_type"] = "np"
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs, output="images")
else:
output_without_adapter = expected_pipe_slice
# 1. Single IP-Adapter test cases
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
)
# 2. Multi IP-Adapter test cases
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
# forward pass with multi ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([0.0, 0.0])
output_without_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with multi ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([42.0, 42.0])
output_with_multi_adapter_scale = pipe(**inputs, output="images")
if expected_pipe_slice is not None:
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_multi_adapter_scale = np.abs(
output_without_multi_adapter_scale - output_without_adapter
).max()
max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_multi_adapter_scale,
expected_max_diff,
"Output without multi-ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_multi_adapter_scale,
1e-2,
"Output with multi-ip-adapter scale must be different from normal inference",
)
def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4):
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
pipe.set_ip_adapter_scale(1.0)
# forward pass with CFG not applied
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)[0].unsqueeze(0)]
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
sample_size = pipe.unet.config.get("sample_size", 32)
block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512])
input_size = sample_size * (2 ** (len(block_out_channels) - 1))
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
output_without_adapter = pipe(**inputs, output="images")
output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter and masks, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs, output="images")
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter and masks, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs, output="images")
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
)
def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
blocks = self.pipeline_blocks_class()
_ = blocks.sub_blocks.pop("ip_adapter")
pipe = blocks.init_pipeline(self.repo)
pipe.load_default_components(torch_dtype=torch.float32)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
output_without_adapter = pipe(**inputs, output="images")
output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs, output="images")
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs, output="images")
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
)
@require_torch
class ModularPipelineTesterMixin:
"""
This mixin is designed to be used with unittest.TestCase classes.
It provides a set of common tests for each PyTorch pipeline, e.g. saving and loading the pipeline,
equivalence of dict and tuple outputs, etc.
"""
# Canonical parameters that are passed to `__call__` regardless
# of the type of pipeline. They are always optional and have common
# sense default values.
required_optional_params = frozenset(
[
"num_inference_steps",
"num_images_per_prompt",
"latents",
"output_type",
]
)
# generator needs to be a intermediate input because it's mutable
required_intermediate_params = frozenset(
[
"generator",
]
)
def get_generator(self, seed):
device = torch_device if torch_device != "mps" else "cpu"
generator = torch.Generator(device).manual_seed(seed)
return generator
@property
def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
raise NotImplementedError(
"You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def repo(self) -> str:
raise NotImplementedError(
"You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
)
@property
def pipeline_blocks_class(self) -> Union[Callable, DiffusionPipeline]:
raise NotImplementedError(
"You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
"See existing pipeline tests for reference."
)
def get_pipeline(self):
raise NotImplementedError(
"You need to implement `get_pipeline(self)` in the child test class. "
"See existing pipeline tests for reference."
)
def get_dummy_inputs(self, device, seed=0):
raise NotImplementedError(
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
"See existing pipeline tests for reference."
)
@property
def params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `params` in the child test class. "
"`params` are checked for if all values are present in `__call__`'s signature."
" You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
" e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
"image pipelines, including prompts and prompt embedding overrides."
"If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
"do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
"with non-configurable height and width arguments should set the attribute as "
"`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
"See existing pipeline tests for reference."
)
@property
def batch_params(self) -> frozenset:
raise NotImplementedError(
"You need to set the attribute `batch_params` in the child test class. "
"`batch_params` are the parameters required to be batched when passed to the pipeline's "
"`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
"`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
"set of batch arguments has minor changes from one of the common sets of batch arguments, "
"do not make modifications to the existing common sets of batch arguments. I.e. a text to "
"image pipeline `negative_prompt` is not batched should set the attribute as "
"`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
"See existing pipeline tests for reference."
)
def setUp(self):
# clean up the VRAM before each test
super().setUp()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def test_pipeline_call_signature(self):
pipe = self.get_pipeline()
parameters = pipe.blocks.input_names
optional_parameters = pipe.default_call_parameters
intermediate_parameters = pipe.blocks.intermediate_input_names
remaining_required_parameters = set()
for param in self.params:
if param not in parameters:
remaining_required_parameters.add(param)
self.assertTrue(
len(remaining_required_parameters) == 0,
f"Required parameters not present: {remaining_required_parameters}",
)
remaining_required_intermediate_parameters = set()
for param in self.required_intermediate_params:
if param not in intermediate_parameters:
remaining_required_intermediate_parameters.add(param)
self.assertTrue(
len(remaining_required_intermediate_parameters) == 0,
f"Required intermediate parameters not present: {remaining_required_intermediate_parameters}",
)
remaining_required_optional_parameters = set()
for param in self.required_optional_params:
if param not in optional_parameters:
remaining_required_optional_parameters.add(param)
self.assertTrue(
len(remaining_required_optional_parameters) == 0,
f"Required optional parameters not present: {remaining_required_optional_parameters}",
)
def test_inference_batch_consistent(self, batch_sizes=[2]):
self._test_inference_batch_consistent(batch_sizes=batch_sizes)
def _test_inference_batch_consistent(
self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True
):
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# prepare batched inputs
batched_inputs = []
for batch_size in batch_sizes:
batched_input = {}
batched_input.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
if name == "prompt":
len_prompt = len(value)
# make unequal batch sizes
batched_input[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
# make last batch super long
batched_input[name][-1] = 100 * "very long"
else:
batched_input[name] = batch_size * [value]
if batch_generator and "generator" in inputs:
batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_input["batch_size"] = batch_size
batched_inputs.append(batched_input)
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
output = pipe(**batched_input, output="images")
assert len(output) == batch_size
def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4):
self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
def _test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
additional_params_copy_to_batched_inputs=["num_inference_steps"],
):
pipe = self.get_pipeline()
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
if name == "prompt":
len_prompt = len(value)
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_inputs[name][-1] = 100 * "very long"
else:
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
output = pipe(**inputs, output="images")
output_batch = pipe(**batched_inputs, output="images")
assert output_batch.shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
assert max_diff < expected_max_diff
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
pipe = self.get_pipeline(torch_dtype=torch.float32)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe_fp16 = self.get_pipeline(torch_dtype=torch.float16)
for component in pipe_fp16.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output="images")
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
output_fp16 = pipe_fp16(**fp16_inputs, output="images")
if isinstance(output, torch.Tensor):
output = output.cpu()
output_fp16 = output_fp16.cpu()
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff
@require_accelerator
def test_to_device(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"), output="images")
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to(torch_device)
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == torch_device for device in model_devices))
output_device = pipe(**self.get_dummy_inputs(torch_device), output="images")
self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_num_images_per_prompt(self):
pipe = self.get_pipeline()
if "num_images_per_prompt" not in pipe.blocks.input_names:
return
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_sizes = [1, 2]
num_images_per_prompts = [1, 2]
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
assert images.shape[0] == batch_size * num_images_per_prompt
def test_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs(torch_device)
out_no_cfg = pipe(**inputs, output="images")
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
def assert_mean_pixel_difference(image, expected_image, expected_max_diff=10):
image = np.asarray(DiffusionPipeline.numpy_to_pil(image)[0], dtype=np.float32)
expected_image = np.asarray(DiffusionPipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
avg_diff = np.abs(image - expected_image).mean()
assert avg_diff < expected_max_diff, f"Error image deviates {avg_diff} pixels on average"