diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md index 0ad599c819..0df1e0e7a0 100644 --- a/docs/source/en/using-diffusers/ip_adapter.md +++ b/docs/source/en/using-diffusers/ip_adapter.md @@ -468,3 +468,83 @@ image
   
+
+### IP-Adapter masking
+
+Binary masks specify which portion of the output image should be assigned to each IP-Adapter.
+For each input IP-Adapter image, a binary mask and an IP-Adapter must be provided.
+
+Before passing the masks to the pipeline, preprocess them with [`IPAdapterMaskProcessor.preprocess()`].
+
+> [!TIP]
+> For optimal results, provide the output height and width to [`IPAdapterMaskProcessor.preprocess()`]. This ensures that masks with a different aspect ratio are appropriately stretched. If the input masks already match the aspect ratio of the generated image, the height and width can be omitted.
+
+Here is an example with two masks:
+
+```py
+from diffusers.image_processor import IPAdapterMaskProcessor
+from diffusers.utils import load_image
+
+mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png")
+mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png")
+
+output_height = 1024
+output_width = 1024
+
+processor = IPAdapterMaskProcessor()
+masks = processor.preprocess([mask1, mask2], height=output_height, width=output_width)
+```
+
+<!-- figure: "mask one" | "mask two" -->
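+
+The masks are ordinary black-and-white images, so they don't have to come from files. A minimal sketch with NumPy and PIL, assuming a 1024x1024 output split into a left and a right half:
+
+```py
+import numpy as np
+from PIL import Image
+
+height, width = 1024, 1024
+
+# White (255) marks the region the corresponding IP-Adapter image should control.
+left = np.zeros((height, width), dtype=np.uint8)
+left[:, : width // 2] = 255
+right = np.zeros((height, width), dtype=np.uint8)
+right[:, width // 2 :] = 255
+
+mask1 = Image.fromarray(left)
+mask2 = Image.fromarray(right)
+```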
+
+If you have more than one IP-Adapter image, load them into a list of lists, so that each image is paired with its own IP-Adapter.
+
+```py
+face_image1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png")
+face_image2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png")
+
+ip_images = [[face_image1], [face_image2]]
+```
+
+<!-- figure: "ip adapter image one" | "ip adapter image two" -->
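+
+Each inner list corresponds to one IP-Adapter, so the number of inner lists should match the number of preprocessed masks. A quick sanity check, assuming the `masks` and `ip_images` defined above:
+
+```py
+# The preprocessed masks form a single 4D tensor: [num_ip_adapters, 1, height, width].
+print(masks.shape)  # expected: torch.Size([2, 1, 1024, 1024])
+
+# One mask per IP-Adapter image group, in the same order.
+assert masks.shape[0] == len(ip_images)
+```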
+
+Pass the preprocessed masks to the pipeline using `cross_attention_kwargs` as shown below:
+
+```py
+pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2)
+pipeline.set_ip_adapter_scale([0.7] * 2)
+
+generator = torch.Generator(device="cpu").manual_seed(0)
+num_images = 1
+
+image = pipeline(
+    prompt="2 girls",
+    ip_adapter_image=ip_images,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=20,
+    num_images_per_prompt=num_images,
+    generator=generator,
+    cross_attention_kwargs={"ip_adapter_masks": masks},
+).images[0]
+```
+
+<!-- figure: "output image" -->
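+
+Internally, the attention processors resize each mask to the current attention resolution with `IPAdapterMaskProcessor.downsample` (added in the diff below) and multiply it into the IP-Adapter hidden states. A rough sketch of the shape math, assuming a square mask and a 64x64 attention resolution (4096 queries):
+
+```py
+import torch
+import torch.nn.functional as F
+
+num_queries = 4096                  # one of the attention resolutions in the UNet
+mask = torch.ones(1, 1024, 1024)    # a single preprocessed mask, shape (1, height, width)
+side = int(num_queries ** 0.5)      # 64 for a square mask
+
+mask_small = F.interpolate(mask.unsqueeze(0), size=(side, side), mode="bicubic").squeeze(0)
+print(mask_small.flatten(1).shape)  # torch.Size([1, 4096]) -> one weight per attention query
+```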
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index f3a5cd3fb9..f6ccfda9fc 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import warnings from typing import List, Optional, Tuple, Union import numpy as np import PIL.Image import torch +import torch.nn.functional as F from PIL import Image, ImageFilter, ImageOps from .configuration_utils import ConfigMixin, register_to_config @@ -882,3 +884,107 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): depth = self.binarize(depth) return rgb, depth + + +class IPAdapterMaskProcessor(VaeImageProcessor): + """ + Image processor for IP Adapter image masks. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `False`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `True`): + Whether to binarize the image to 0/1. + do_convert_grayscale (`bool`, *optional*, defaults to be `True`): + Whether to convert the images to grayscale format. + + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + resample: str = "lanczos", + do_normalize: bool = False, + do_binarize: bool = True, + do_convert_grayscale: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + resample=resample, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_grayscale=do_convert_grayscale, + ) + + @staticmethod + def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int): + """ + Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. + If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. + + Args: + mask (`torch.FloatTensor`): + The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`. + batch_size (`int`): + The batch size. + num_queries (`int`): + The number of queries. + value_embed_dim (`int`): + The dimensionality of the value embeddings. + + Returns: + `torch.FloatTensor`: + The downsampled mask tensor. 
+ + """ + o_h = mask.shape[1] + o_w = mask.shape[2] + ratio = o_w / o_h + mask_h = int(math.sqrt(num_queries / ratio)) + mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0) + mask_w = num_queries // mask_h + + mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) + + # Repeat batch_size times + if mask_downsample.shape[0] < batch_size: + mask_downsample = mask_downsample.repeat(batch_size, 1, 1) + + mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1) + + downsampled_area = mask_h * mask_w + # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match + # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries + if downsampled_area < num_queries: + warnings.warn( + "The aspect ratio of the mask does not match the aspect ratio of the output image. " + "Please update your masks or adjust the output size for optimal performance.", + UserWarning, + ) + mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0) + # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries + if downsampled_area > num_queries: + warnings.warn( + "The aspect ratio of the mask does not match the aspect ratio of the output image. " + "Please update your masks or adjust the output size for optimal performance.", + UserWarning, + ) + mask_downsample = mask_downsample[:, :num_queries] + + # Repeat last dimension to match SDPA output shape + mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat( + 1, 1, value_embed_dim + ) + + return mask_downsample diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 8acab015b3..1c008264ba 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -19,6 +19,7 @@ import torch import torch.nn.functional as F from torch import nn +from ..image_processor import IPAdapterMaskProcessor from ..utils import USE_PEFT_BACKEND, deprecate, logging from ..utils.import_utils import is_xformers_available from ..utils.torch_utils import maybe_allow_in_graph @@ -2107,12 +2108,13 @@ class IPAdapterAttnProcessor(nn.Module): def __call__( self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale=1.0, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.FloatTensor] = None, ): residual = hidden_states @@ -2167,9 +2169,22 @@ class IPAdapterAttnProcessor(nn.Module): hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: + raise ValueError( + " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if len(ip_adapter_masks) != len(self.scale): + raise ValueError( + f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) @@ -2181,6 +2196,15 @@ class IPAdapterAttnProcessor(nn.Module): current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + if mask is not None: + mask_downsample = IPAdapterMaskProcessor.downsample( + mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + + current_ip_hidden_states = current_ip_hidden_states * mask_downsample + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj @@ -2244,12 +2268,13 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): def __call__( self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - scale=1.0, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[torch.FloatTensor] = None, ): residual = hidden_states @@ -2318,9 +2343,22 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: + raise ValueError( + " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if len(ip_adapter_masks) != len(self.scale): + raise ValueError( + f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + # for ip-adapter - for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( - ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): ip_key = to_k_ip(current_ip_hidden_states) ip_value = to_v_ip(current_ip_hidden_states) @@ -2339,6 +2377,15 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): ) current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + if mask is not None: + mask_downsample = IPAdapterMaskProcessor.downsample( + mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] + ) + + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + + current_ip_hidden_states = current_ip_hidden_states * mask_downsample + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 11066253c5..6289ee887d 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -31,6 +31,7 @@ from diffusers import ( StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) +from diffusers.image_processor import IPAdapterMaskProcessor from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 from diffusers.utils import load_image from diffusers.utils.testing_utils import ( @@ -64,7 +65,7 @@ class IPAdapterNightlyTestsMixin(unittest.TestCase): image_processor = CLIPImageProcessor.from_pretrained(repo_id) return image_processor - def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False): + def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False): image = load_image( "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png" ) @@ -101,6 +102,22 @@ class IPAdapterNightlyTestsMixin(unittest.TestCase): input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image}) + elif for_masks: + face_image1 = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png" + ) + face_image2 = load_image( + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png" + ) + mask1 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png") + mask2 = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png") + input_kwargs.update( + { + "ip_adapter_image": [[face_image1], [face_image2]], + "cross_attention_kwargs": {"ip_adapter_masks": [mask1, mask2]}, + } + ) + return input_kwargs @@ -465,3 +482,58 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 + + def test_ip_adapter_single_mask(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = 
StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors" + ) + pipeline.set_ip_adapter_scale(0.7) + + inputs = self.get_dummy_inputs(for_masks=True) + mask = inputs["cross_attention_kwargs"]["ip_adapter_masks"][0] + processor = IPAdapterMaskProcessor() + mask = processor.preprocess(mask) + inputs["cross_attention_kwargs"]["ip_adapter_masks"] = mask + inputs["ip_adapter_image"] = inputs["ip_adapter_image"][0] + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + expected_slice = np.array( + [0.7307304, 0.73450166, 0.73731124, 0.7377061, 0.7318013, 0.73720926, 0.74746597, 0.7409929, 0.74074936] + ) + + max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) + assert max_diff < 5e-4 + + def test_ip_adapter_multiple_masks(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2 + ) + pipeline.set_ip_adapter_scale([0.7] * 2) + + inputs = self.get_dummy_inputs(for_masks=True) + masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"] + processor = IPAdapterMaskProcessor() + masks = processor.preprocess(masks) + inputs["cross_attention_kwargs"]["ip_adapter_masks"] = masks + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + expected_slice = np.array( + [0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424] + ) + + max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) + assert max_diff < 5e-4
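
A standalone shape check for `IPAdapterMaskProcessor.downsample` could complement the nightly pipeline tests above; a sketch, deliberately using a 2:1 mask so the padding branch (and its aspect-ratio warning) is exercised:

```py
import torch

from diffusers.image_processor import IPAdapterMaskProcessor


def test_downsample_pads_to_num_queries():
    # A 2:1 mask downsampled for 4096 queries lands on a 46x89 grid (4094 values),
    # so the helper pads back up to num_queries before repeating the value dimension.
    mask = torch.ones(1, 512, 1024)
    out = IPAdapterMaskProcessor.downsample(mask, batch_size=1, num_queries=4096, value_embed_dim=8)
    assert out.shape == (1, 4096, 8)
```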