Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Multi-image masking for single IP Adapter (#7499)
* Support multi-image masking

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
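The diff below adds an SDXL integration test and mixin-level fast tests for this feature. For orientation, here is a minimal sketch (not part of the commit) of the user-facing pattern the integration test exercises: a single loaded IP Adapter driven by two reference images, each confined to its own region through `IPAdapterMaskProcessor`. The model IDs, weight file, scale format, and mask reshape mirror the test; the prompt, the generated half-image masks, and the `face_1.png` / `face_2.png` inputs are placeholders.

```python
import torch
from PIL import Image
from transformers import CLIPVisionModelWithProjection

from diffusers import StableDiffusionXLPipeline
from diffusers.image_processor import IPAdapterMaskProcessor

# The "plus" face adapter needs the ViT-H image encoder loaded explicitly,
# as the integration test does via get_image_encoder().
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", image_encoder=image_encoder, torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()
pipeline.load_ip_adapter(
    "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
)
# One inner list per loaded adapter; two scales because two images feed this single adapter.
pipeline.set_ip_adapter_scale([[0.7, 0.7]])

# Placeholder region masks: left half for the first reference image, right half for the second.
mask_left = Image.new("L", (1024, 1024), 0)
mask_left.paste(255, (0, 0, 512, 1024))
mask_right = Image.new("L", (1024, 1024), 0)
mask_right.paste(255, (512, 0, 1024, 1024))

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask_left, mask_right], height=1024, width=1024)
# Fold the (num_masks, 1, H, W) stack into (1, num_masks, H, W), as the test does,
# so the single adapter receives one mask per reference image.
masks = [masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])]

# Hypothetical reference images; substitute your own files.
face_1 = Image.open("face_1.png").convert("RGB")
face_2 = Image.open("face_2.png").convert("RGB")

image = pipeline(
    prompt="two people standing side by side",  # placeholder prompt
    ip_adapter_image=[[face_1, face_2]],        # one inner list per adapter
    cross_attention_kwargs={"ip_adapter_masks": masks},
    num_inference_steps=30,
).images[0]
```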
@@ -544,3 +544,33 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4

    def test_ip_adapter_multiple_masks_one_adapter(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            image_encoder=image_encoder,
            torch_dtype=self.dtype,
        )
        pipeline.enable_model_cpu_offload()
        pipeline.load_ip_adapter(
            "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
        )
        pipeline.set_ip_adapter_scale([[0.7, 0.7]])

        inputs = self.get_dummy_inputs(for_masks=True)
        masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"]
        processor = IPAdapterMaskProcessor()
        masks = processor.preprocess(masks)
        masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])
        inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks]
        ip_images = inputs["ip_adapter_image"]
        inputs["ip_adapter_image"] = [[image[0] for image in ip_images]]
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()
        expected_slice = np.array(
            [0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424]
        )

        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
        assert max_diff < 5e-4
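The only non-obvious step in the test above is how the preprocessed masks are handed to a single adapter. A shape-only sketch (toy sizes, not the test's real resolution):

```python
import torch

# IPAdapterMaskProcessor.preprocess returns one grayscale mask per input image,
# stacked as (num_masks, 1, H, W). For one adapter fed by several images, the test
# folds that into (1, num_masks, H, W) and wraps it in a one-element list: one
# entry per loaded adapter, each entry holding one mask per reference image.
num_masks, height, width = 2, 64, 64             # toy sizes
masks = torch.rand(num_masks, 1, height, width)  # stand-in for preprocess() output
packed = masks.reshape(1, num_masks, height, width)
ip_adapter_masks = [packed]                      # list length == number of adapters
assert packed.shape == (1, 2, 64, 64)
```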
@@ -238,6 +238,11 @@ class IPAdapterTesterMixin:
    def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
        return torch.randn((2, 1, cross_attention_dim), device=torch_device)

    def _get_dummy_masks(self, input_size: int = 64):
        _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
        _masks[0, :, :, : int(input_size / 2)] = 1
        return _masks

    def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
        parameters = inspect.signature(self.pipeline_class.__call__).parameters
        if "image" in parameters.keys() and "strength" in parameters.keys():
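For reference, the new `_get_dummy_masks` helper builds a half-image mask. A tiny sketch of what it produces, at a toy size:

```python
import torch

# Same construction as _get_dummy_masks: a (1, 1, S, S) mask whose left half is 1
# and right half is 0, so the dummy IP-Adapter image only conditions the left half
# of the output.
input_size = 8  # the mixin defaults to 64
mask = torch.zeros((1, 1, input_size, input_size))
mask[0, :, :, : input_size // 2] = 1
print(mask[0, 0, 0])  # tensor([1., 1., 1., 1., 0., 0., 0., 0.])
```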
@@ -365,6 +370,51 @@ class IPAdapterTesterMixin:
        assert out_cfg.shape == out_no_cfg.shape

    def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
        sample_size = pipe.unet.config.get("sample_size", 32)
        block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512])
        input_size = sample_size * (2 ** (len(block_out_channels) - 1))

        # forward pass without ip adapter
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        output_without_adapter = pipe(**inputs)[0]
        output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()

        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
        pipe.unet._load_ip_adapter_weights(adapter_state_dict)

        # forward pass with single ip adapter and masks, but scale=0 which should have no effect
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
        inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
        pipe.set_ip_adapter_scale(0.0)
        output_without_adapter_scale = pipe(**inputs)[0]
        output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()

        # forward pass with single ip adapter and masks, but with scale of adapter weights
        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
        inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
        pipe.set_ip_adapter_scale(42.0)
        output_with_adapter_scale = pipe(**inputs)[0]
        output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()

        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()

        self.assertLess(
            max_diff_without_adapter_scale,
            expected_max_diff,
            "Output without ip-adapter must be same as normal inference",
        )
        self.assertGreater(
            max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
        )


class PipelineLatentTesterMixin:
    """
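One detail worth spelling out in `test_ip_adapter_masks`: the dummy mask is built at the pipeline's pixel resolution, which the test derives from the UNet's latent `sample_size` and the VAE's downsampling factor. With the default config values used above:

```python
# 2 ** (len(block_out_channels) - 1) is the VAE downsampling factor: one 2x
# downsample per block transition. With the defaults in the test, a 32x32 latent
# corresponds to a 256x256 pixel-space mask.
sample_size = 32
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 2**3 = 8
input_size = sample_size * vae_scale_factor            # 32 * 8 = 256
assert (vae_scale_factor, input_size) == (8, 256)
```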