diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index 0ef90c6dd9..10f8cbd999 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -506,22 +506,11 @@ import torch from diffusers import StableDiffusionPipeline, DDIMScheduler from diffusers.utils import load_image -noise_scheduler = DDIMScheduler( - num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - steps_offset=1 -) - pipeline = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, - scheduler=noise_scheduler, ).to("cuda") - +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin") pipeline.set_ip_adapter_scale(0.7) @@ -550,6 +539,66 @@ image = pipeline( + +You can load multiple IP-Adapter models and use multiple reference images at the same time. In this example we use IP-Adapter-Plus face model to create a consistent character and also use IP-Adapter-Plus model along with 10 images to create a coherent style in the image we generate. + +```python +import torch +from diffusers import AutoPipelineForText2Image, DDIMScheduler +from transformers import CLIPVisionModelWithProjection +from diffusers.utils import load_image + +image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "h94/IP-Adapter", + subfolder="models/image_encoder", + torch_dtype=torch.float16, +) + +pipeline = AutoPipelineForText2Image.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, + image_encoder=image_encoder, +) +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"] +) +pipeline.set_ip_adapter_scale([0.7, 0.3]) +pipeline.enable_model_cpu_offload() + +face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png") +style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy" +style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)] + +generator = torch.Generator(device="cpu").manual_seed(0) + +image = pipeline( + prompt="wonderwoman", + ip_adapter_image=[style_images, face_image], + negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", + num_inference_steps=50, num_images_per_prompt=1, + generator=generator, +).images[0] +``` +
+<!-- image grid: style input image | face input image | output image -->
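+
+Because both adapters stay registered on the UNet, you can re-balance them between generations by calling `set_ip_adapter_scale` again with one value per adapter, in the order the adapters were loaded. The snippet below is only a sketch and the scale values are illustrative: it mutes the style adapter and leans mostly on the face adapter, reusing the images and generator from the example above.
+
+```python
+# first entry is the style adapter, second is the face adapter (load order)
+pipeline.set_ip_adapter_scale([0.0, 0.9])
+
+image = pipeline(
+    prompt="wonderwoman",
+    ip_adapter_image=[style_images, face_image],
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=50,
+    generator=generator,
+).images[0]
+```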
+ + ### LCM-Lora You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights. diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index f2ac58cf93..679c46d57e 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from pathlib import Path -from typing import Dict, Union +from typing import Dict, List, Union import torch from huggingface_hub.utils import validate_hf_hub_args @@ -45,9 +46,9 @@ class IPAdapterMixin: @validate_hf_hub_args def load_ip_adapter( self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - subfolder: str, - weight_name: str, + pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]], + subfolder: Union[str, List[str]], + weight_name: Union[str, List[str]], **kwargs, ): """ @@ -87,6 +88,26 @@ class IPAdapterMixin: The subfolder location of a model file within a larger model repository on the Hub or locally. """ + # handle the list inputs for multiple IP Adapters + if not isinstance(weight_name, list): + weight_name = [weight_name] + + if not isinstance(pretrained_model_name_or_path_or_dict, list): + pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict] + if len(pretrained_model_name_or_path_or_dict) == 1: + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name) + + if not isinstance(subfolder, list): + subfolder = [subfolder] + if len(subfolder) == 1: + subfolder = subfolder * len(weight_name) + + if len(weight_name) != len(pretrained_model_name_or_path_or_dict): + raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.") + + if len(weight_name) != len(subfolder): + raise ValueError("`weight_name` and `subfolder` must have the same length.") + # Load the main state dict first. 
cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) @@ -100,61 +121,68 @@ class IPAdapterMixin: "file_type": "attn_procs_weights", "framework": "pytorch", } - - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - if weight_name.endswith(".safetensors"): - state_dict = {"image_proj": {}, "ip_adapter": {}} - with safe_open(model_file, framework="pt", device="cpu") as f: - for key in f.keys(): - if key.startswith("image_proj."): - state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) - elif key.startswith("ip_adapter."): - state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) - else: - state_dict = torch.load(model_file, map_location="cpu") - else: - state_dict = pretrained_model_name_or_path_or_dict - - keys = list(state_dict.keys()) - if keys != ["image_proj", "ip_adapter"]: - raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") - - # load CLIP image encoder here if it has not been registered to the pipeline yet - if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + state_dicts = [] + for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip( + pretrained_model_name_or_path_or_dict, weight_name, subfolder + ): if not isinstance(pretrained_model_name_or_path_or_dict, dict): - logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") - image_encoder = CLIPVisionModelWithProjection.from_pretrained( + model_file = _get_model_file( pretrained_model_name_or_path_or_dict, - subfolder=Path(subfolder, "image_encoder").as_posix(), - ).to(self.device, dtype=self.dtype) - self.image_encoder = image_encoder - self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") else: - raise ValueError("`image_encoder` cannot be None when using IP Adapters.") + state_dict = pretrained_model_name_or_path_or_dict - # create feature extractor if it has not been registered to the pipeline yet - if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: - self.feature_extractor = CLIPImageProcessor() - self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"]) + keys = list(state_dict.keys()) + if keys != ["image_proj", "ip_adapter"]: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") - # load ip-adapter into 
unet + state_dicts.append(state_dict) + + # load CLIP image encoder here if it has not been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + pretrained_model_name_or_path_or_dict, + subfolder=Path(subfolder, "image_encoder").as_posix(), + ).to(self.device, dtype=self.dtype) + self.image_encoder = image_encoder + self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) + else: + raise ValueError("`image_encoder` cannot be None when using IP Adapters.") + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: + feature_extractor = CLIPImageProcessor() + self.register_modules(feature_extractor=feature_extractor) + + # load ip-adapter into unet unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet._load_ip_adapter_weights(state_dict) + unet._load_ip_adapter_weights(state_dicts) def set_ip_adapter_scale(self, scale): + if not isinstance(scale, list): + scale = [scale] unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet for attn_processor in unet.attn_processors.values(): if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index d7c145e430..d359521e91 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -25,7 +25,12 @@ import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args from torch import nn -from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection +from ..models.embeddings import ( + ImageProjection, + IPAdapterFullImageProjection, + IPAdapterPlusImageProjection, + MultiIPAdapterImageProjection, +) from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( USE_PEFT_BACKEND, @@ -763,7 +768,7 @@ class UNet2DConditionLoadersMixin: image_projection.load_state_dict(updated_state_dict) return image_projection - def _load_ip_adapter_weights(self, state_dict): + def _convert_ip_adapter_attn_to_diffusers(self, state_dicts): from ..models.attention_processor import ( AttnProcessor, AttnProcessor2_0, @@ -771,20 +776,6 @@ class UNet2DConditionLoadersMixin: IPAdapterAttnProcessor2_0, ) - if "proj.weight" in state_dict["image_proj"]: - # IP-Adapter - num_image_text_embeds = 4 - elif "proj.3.weight" in state_dict["image_proj"]: - # IP-Adapter Full Face - num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token - else: - # IP-Adapter Plus - num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] - - # Set encoder_hid_proj after loading ip_adapter weights, - # because `IPAdapterPlusImageProjection` also has `attn_processors`. 
- self.encoder_hid_proj = None - # set ip-adapter cross-attention processors & load state_dict attn_procs = {} key_id = 1 @@ -798,6 +789,7 @@ class UNet2DConditionLoadersMixin: elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = self.config.block_out_channels[block_id] + if cross_attention_dim is None or "motion_modules" in name: attn_processor_class = ( AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor @@ -807,6 +799,18 @@ class UNet2DConditionLoadersMixin: attn_processor_class = ( IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor ) + num_image_text_embeds = [] + for state_dict in state_dicts: + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + num_image_text_embeds += [4] + elif "proj.3.weight" in state_dict["image_proj"]: + # IP-Adapter Full Face + num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + else: + # IP-Adapter Plus + num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] + attn_procs[name] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, @@ -815,16 +819,31 @@ class UNet2DConditionLoadersMixin: ).to(dtype=self.dtype, device=self.device) value_dict = {} - for k, w in attn_procs[name].state_dict().items(): - value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) + for i, state_dict in enumerate(state_dicts): + value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) + value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) attn_procs[name].load_state_dict(value_dict) key_id += 2 + return attn_procs + + def _load_ip_adapter_weights(self, state_dicts): + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + # Set encoder_hid_proj after loading ip_adapter weights, + # because `IPAdapterPlusImageProjection` also has `attn_processors`. + self.encoder_hid_proj = None + + attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts) self.set_attn_processor(attn_procs) # convert IP-Adapter Image Projection layers to diffusers - image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) + image_projection_layer.to(device=self.device, dtype=self.dtype) + image_projection_layers.append(image_projection_layer) - self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ac9563e186..908946119d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2087,29 +2087,41 @@ class LoRAAttnAddedKVProcessor(nn.Module): class IPAdapterAttnProcessor(nn.Module): r""" - Attention processor for IP-Adapater. + Attention processor for Multiple IP-Adapater. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. - num_tokens (`int`, defaults to 4): + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): The context length of the image features. 
- scale (`float`, defaults to 1.0): + scale (`float` or `List[float]`, defaults to 1.0): the weight scale of image prompt. 
+ ) + deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -2148,13 +2174,6 @@ class IPAdapterAttnProcessor(nn.Module): elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - # split hidden states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -2167,17 +2186,20 @@ class IPAdapterAttnProcessor(nn.Module): hidden_states = attn.batch_to_head_dim(hidden_states) # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) - hidden_states = hidden_states + self.scale * ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -2204,13 +2226,13 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. - num_tokens (`int`, defaults to 4): + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): The context length of the image features. - scale (`float`, defaults to 1.0): + scale (`float` or `List[float]`, defaults to 1.0): the weight scale of image prompt. 
""" - def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0): + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): @@ -2220,11 +2242,23 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim + + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] self.num_tokens = num_tokens + + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") self.scale = scale - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_k_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) + self.to_v_ip = nn.ModuleList( + [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))] + ) def __call__( self, @@ -2235,10 +2269,24 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): temb=None, scale=1.0, ): - if scale != 1.0: - logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.") residual = hidden_states + # separate ip_hidden_states from encoder_hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = ( + "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning." 
+ ) + deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -2268,13 +2316,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - # split hidden states - end_pos = encoder_hidden_states.shape[1] - self.num_tokens - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - encoder_hidden_states[:, end_pos:, :], - ) - key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -2296,22 +2337,27 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): hidden_states = hidden_states.to(query.dtype) # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) + for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip + ): + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) - hidden_states = hidden_states + self.scale * ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 293b751cb6..1ef035af10 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -from typing import Optional +from typing import List, Optional, Tuple, Union import numpy as np import torch from torch import nn -from ..utils import USE_PEFT_BACKEND +from ..utils import USE_PEFT_BACKEND, deprecate from .activations import get_activation from .attention_processor import Attention from .lora import LoRACompatibleLinear @@ -878,3 +878,38 @@ class IPAdapterPlusImageProjection(nn.Module): latents = self.proj_out(latents) return self.norm_out(latents) + + +class MultiIPAdapterImageProjection(nn.Module): + def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): + super().__init__() + self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) + + def forward(self, image_embeds: List[torch.FloatTensor]): + projected_image_embeds = [] + + # currently, we accept `image_embeds` as + # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] + # 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] + if not isinstance(image_embeds, list): + deprecation_message = ( + "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." + " Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning." + ) + deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) + image_embeds = [image_embeds.unsqueeze(1)] + + if len(image_embeds) != len(self.image_projection_layers): + raise ValueError( + f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" + ) + + for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): + batch_size, num_images = image_embed.shape[0], image_embed.shape[1] + image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) + image_embed = image_projection_layer(image_embed) + image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) + + projected_image_embeds.append(image_embed) + + return projected_image_embeds diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 87297b5b5d..f45ef239be 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1074,8 +1074,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") - image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) - encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) # 2. 
pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index ee1062ee81..5988e7657e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -426,6 +426,35 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -1002,12 +1031,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_videos_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 6cd1658c59..f0c39952fd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -506,6 +506,35 @@ class StableDiffusionControlNetPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -1083,12 +1112,9 @@ class StableDiffusionControlNetPipeline( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. 
Prepare image if isinstance(controlnet, ControlNetModel): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 6e00134591..f5e4775900 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -499,6 +499,35 @@ class StableDiffusionControlNetImg2ImgPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -1087,12 +1116,9 @@ class StableDiffusionControlNetImg2ImgPipeline( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. 
Prepare image image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 9f3009cede..bc6133c8b2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -624,6 +624,35 @@ class StableDiffusionControlNetInpaintPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -1335,12 +1364,9 @@ class StableDiffusionControlNetInpaintPipeline( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. 
Prepare image if isinstance(controlnet, ControlNetModel): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 78793c2866..5165d193dc 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -515,6 +515,35 @@ class StableDiffusionXLControlNetPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -1182,12 +1211,9 @@ class StableDiffusionXLControlNetPipeline( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. 
Prepare image if isinstance(controlnet, ControlNetModel): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 12ff9bbbfb..dda2f207b9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -564,6 +564,35 @@ class StableDiffusionXLControlNetImg2ImgPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -1340,12 +1369,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. 
Prepare image and controlnet_conditioning_image image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index e772d8be2a..60707cc1e2 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1280,8 +1280,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") - image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) - encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 63a54f5aa6..4146a35fb9 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -477,6 +477,35 @@ class LatentConsistencyModelImg2ImgPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -790,9 +819,8 @@ class LatentConsistencyModelImg2ImgPipeline( # do_classifier_free_guidance = guidance_scale > 1.0 if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) # 3. Encode input prompt diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index dc4ad60ce0..46b834f9ce 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -514,6 +514,34 @@ class StableDiffusionPipeline( return image_embeds, uncond_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -949,12 +977,9 @@ class StableDiffusionPipeline( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 45dbd1128d..f78cd383b8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -528,6 +528,35 @@ class StableDiffusionImg2ImgPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -1001,12 +1030,9 @@ class StableDiffusionImg2ImgPipeline( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Preprocess image image = self.image_processor.preprocess(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 6751490abd..5d77341511 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -600,6 +600,35 @@ class StableDiffusionInpaintPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -1209,12 +1238,9 @@ class StableDiffusionInpaintPipeline( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. set timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index 699bd10041..5b3aae13f4 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -438,6 +438,35 @@ class StableDiffusionLDM3DPipeline( return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+            )
+
+        image_embeds = []
+        for single_ip_adapter_image, image_proj_layer in zip(
+            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+        ):
+            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(
+                single_ip_adapter_image, device, 1, output_hidden_state
+            )
+            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+            if self.do_classifier_free_guidance:
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                single_image_embeds = single_image_embeds.to(device)
+
+            image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
             has_nsfw_concept = None
@@ -654,12 +683,9 @@ class StableDiffusionLDM3DPipeline(
         do_classifier_free_guidance = guidance_scale > 1.0

         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
             )
-            if do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])

         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index f0ef4b9f88..edf93839de 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -396,6 +396,35 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM

         return image_embeds, uncond_image_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+        if not isinstance(ip_adapter_image, list):
+            ip_adapter_image = [ip_adapter_image]
+
+        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            raise ValueError(
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+            )
+
+        image_embeds = []
+        for single_ip_adapter_image, image_proj_layer in zip(
+            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+        ):
+            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(
+                single_ip_adapter_image, device, 1, output_hidden_state
+            )
+            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+            if self.do_classifier_free_guidance:
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                single_image_embeds = single_image_embeds.to(device)
+
+            image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -669,12 +698,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         do_classifier_free_guidance = guidance_scale > 1.0

         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
             )
-            if do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])

         # 3. Encode input prompt
         text_encoder_lora_scale = (
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index f9bafc9733..79f0aa379a 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -549,6 +549,35 @@ class StableDiffusionXLPipeline(

         return image_embeds, uncond_image_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+        if not isinstance(ip_adapter_image, list):
+            ip_adapter_image = [ip_adapter_image]
+
+        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            raise ValueError(
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+            )
+
+        image_embeds = []
+        for single_ip_adapter_image, image_proj_layer in zip(
+            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+        ):
+            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(
+                single_ip_adapter_image, device, 1, output_hidden_state
+            )
+            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+            if self.do_classifier_free_guidance:
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                single_image_embeds = single_image_embeds.to(device)
+
+            image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1163,13 +1192,9 @@ class StableDiffusionXLPipeline(
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
             )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
-            image_embeds = image_embeds.to(device)

         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 1c22affba1..76416a6d33 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -766,6 +766,35 @@ class StableDiffusionXLImg2ImgPipeline(

         return image_embeds, uncond_image_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+        if not isinstance(ip_adapter_image, list):
+            ip_adapter_image = [ip_adapter_image]
+
+        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            raise ValueError(
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+            )
+
+        image_embeds = []
+        for single_ip_adapter_image, image_proj_layer in zip(
+            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+        ):
+            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(
+                single_ip_adapter_image, device, 1, output_hidden_state
+            )
+            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+            if self.do_classifier_free_guidance:
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                single_image_embeds = single_image_embeds.to(device)
+
+            image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     def _get_add_time_ids(
         self,
         original_size,
@@ -1337,13 +1366,9 @@ class StableDiffusionXLImg2ImgPipeline(
         add_time_ids = add_time_ids.to(device)

         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
             )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
-            image_embeds = image_embeds.to(device)

         # 9. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index f9468adba9..248c990b2c 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -487,6 +487,35 @@ class StableDiffusionXLInpaintPipeline(

         return image_embeds, uncond_image_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+        if not isinstance(ip_adapter_image, list):
+            ip_adapter_image = [ip_adapter_image]
+
+        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            raise ValueError(
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+            )
+
+        image_embeds = []
+        for single_ip_adapter_image, image_proj_layer in zip(
+            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+        ):
+            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(
+                single_ip_adapter_image, device, 1, output_hidden_state
+            )
+            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+            if self.do_classifier_free_guidance:
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                single_image_embeds = single_image_embeds.to(device)
+
+            image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
@@ -1685,13 +1714,9 @@ class StableDiffusionXLInpaintPipeline(
         add_time_ids = add_time_ids.to(device)

         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
             )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
-            image_embeds = image_embeds.to(device)

         # 11. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 5ec35ddf07..1e97ce4da4 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -563,6 +563,35 @@ class StableDiffusionXLAdapterPipeline(

         return image_embeds, uncond_image_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+        if not isinstance(ip_adapter_image, list):
+            ip_adapter_image = [ip_adapter_image]
+
+        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            raise ValueError(
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+            )
+
+        image_embeds = []
+        for single_ip_adapter_image, image_proj_layer in zip(
+            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+        ):
+            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(
+                single_ip_adapter_image, device, 1, output_hidden_state
+            )
+            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+            if self.do_classifier_free_guidance:
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                single_image_embeds = single_image_embeds.to(device)
+
+            image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1068,12 +1097,9 @@ class StableDiffusionXLAdapterPipeline(
         # 3.2 Encode ip_adapter_image
         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
             )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])

         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 0e2a4765d6..1e8c4dca61 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -25,7 +25,11 @@ from parameterized import parameterized
 from pytest import mark

 from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
+from diffusers.models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
+)
 from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
@@ -73,8 +77,8 @@ def create_ip_adapter_state_dict(model):
         ).state_dict()
         ip_cross_attn_state_dict.update(
             {
-                f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
-                f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
+                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
             }
         )
@@ -124,8 +128,8 @@ def create_ip_adapter_plus_state_dict(model):
         ).state_dict()
         ip_cross_attn_state_dict.update(
             {
-                f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
-                f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
+                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
             }
         )
@@ -773,8 +777,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         # update inputs_dict for ip-adapter
         batch_size = inputs_dict["encoder_hidden_states"].shape[0]
+        # for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim]
         image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
-        inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+        inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}

         # make ip_adapter_1 and ip_adapter_2
         ip_adapter_1 = create_ip_adapter_state_dict(model)
@@ -785,7 +790,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})

         # forward pass ip_adapter_1
-        model._load_ip_adapter_weights(ip_adapter_1)
+        model._load_ip_adapter_weights([ip_adapter_1])
         assert model.config.encoder_hid_dim_type == "ip_image_proj"
         assert model.encoder_hid_proj is not None
         assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
@@ -796,18 +801,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
             sample2 = model(**inputs_dict).sample

         # forward pass with ip_adapter_2
-        model._load_ip_adapter_weights(ip_adapter_2)
+        model._load_ip_adapter_weights([ip_adapter_2])
         with torch.no_grad():
             sample3 = model(**inputs_dict).sample

         # forward pass with ip_adapter_1 again
-        model._load_ip_adapter_weights(ip_adapter_1)
+        model._load_ip_adapter_weights([ip_adapter_1])
         with torch.no_grad():
             sample4 = model(**inputs_dict).sample

+        # forward pass with multiple ip-adapters and multiple images
+        model._load_ip_adapter_weights([ip_adapter_1, ip_adapter_2])
+        # set the scale for ip_adapter_2 to 0 so that result should be same as only load ip_adapter_1
+        for attn_processor in model.attn_processors.values():
+            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+                attn_processor.scale = [1, 0]
+        image_embeds_multi = image_embeds.repeat(1, 2, 1)
+        inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds_multi, image_embeds_multi]}
+        with torch.no_grad():
+            sample5 = model(**inputs_dict).sample
+
+        # forward pass with single ip-adapter & single image when image_embeds is not a list and a 2-d tensor
+        image_embeds = image_embeds.squeeze(1)
+        inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+
+        model._load_ip_adapter_weights(ip_adapter_1)
+        with torch.no_grad():
+            sample6 = model(**inputs_dict).sample
+
         assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
         assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
+        assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
+        assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)

     def test_ip_adapter_plus(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -823,8 +849,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         # update inputs_dict for ip-adapter
         batch_size = inputs_dict["encoder_hidden_states"].shape[0]
-        image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
-        inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+        # for ip-adapter-plus image_embeds has shape [batch_size, num_image, sequence_length, embed_dim]
+        image_embeds = floats_tensor((batch_size, 1, 1, model.cross_attention_dim)).to(torch_device)
+        inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}

         # make ip_adapter_1 and ip_adapter_2
         ip_adapter_1 = create_ip_adapter_plus_state_dict(model)
@@ -835,7 +862,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})

         # forward pass ip_adapter_1
-        model._load_ip_adapter_weights(ip_adapter_1)
+        model._load_ip_adapter_weights([ip_adapter_1])
         assert model.config.encoder_hid_dim_type == "ip_image_proj"
         assert model.encoder_hid_proj is not None
         assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
@@ -846,18 +873,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
             sample2 = model(**inputs_dict).sample

         # forward pass with ip_adapter_2
-        model._load_ip_adapter_weights(ip_adapter_2)
+        model._load_ip_adapter_weights([ip_adapter_2])
         with torch.no_grad():
             sample3 = model(**inputs_dict).sample

         # forward pass with ip_adapter_1 again
-        model._load_ip_adapter_weights(ip_adapter_1)
+        model._load_ip_adapter_weights([ip_adapter_1])
         with torch.no_grad():
             sample4 = model(**inputs_dict).sample

+        # forward pass with multiple ip-adapters and multiple images
+        model._load_ip_adapter_weights([ip_adapter_1, ip_adapter_2])
+        # set the scale for ip_adapter_2 to 0 so that result should be same as only load ip_adapter_1
+        for attn_processor in model.attn_processors.values():
+            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+                attn_processor.scale = [1, 0]
+        image_embeds_multi = image_embeds.repeat(1, 2, 1, 1)
+        inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds_multi, image_embeds_multi]}
+        with torch.no_grad():
+            sample5 = model(**inputs_dict).sample
+
+        # forward pass with single ip-adapter & single image when image_embeds is a 3-d tensor
+        image_embeds = image_embeds[:,].squeeze(1)
+        inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+
+        model._load_ip_adapter_weights(ip_adapter_1)
+        with torch.no_grad():
+            sample6 = model(**inputs_dict).sample
+
         assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
         assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
+        assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
+        assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)


 @slow
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index 4f0342e4fb..710dea3c2d 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -258,6 +258,27 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         ]
         assert processors == [True] * len(processors)

+    def test_multi(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
+        )
+        pipeline.to(torch_device)
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter", subfolder="models", weight_name=["ip-adapter_sd15.bin", "ip-adapter-plus_sd15.bin"]
+        )
+        pipeline.set_ip_adapter_scale([0.7, 0.3])
+
+        inputs = self.get_dummy_inputs()
+        ip_adapter_image = inputs["ip_adapter_image"]
+        inputs["ip_adapter_image"] = [ip_adapter_image, [ip_adapter_image] * 2]
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array(
+            [0.5234375, 0.53515625, 0.5629883, 0.57128906, 0.59521484, 0.62109375, 0.57910156, 0.6201172, 0.6508789]
+        )
+        assert np.allclose(image_slice, expected_slice, atol=1e-3)
+

 @slow
 @require_torch_gpu
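
All of the pipeline changes above route `ip_adapter_image` through the new `prepare_ip_adapter_image_embeds` helper: each image (or list of images) is paired with one of the UNet's `encoder_hid_proj.image_projection_layers`, encoded, and returned as one embedding tensor per adapter, with the negative embeddings concatenated ahead of the positive ones when classifier-free guidance is active. The snippet below is a minimal standalone sketch of that control flow, not the patched pipelines themselves; the `FakeProjectionLayer` class, the `encode_image` stub, and all tensor shapes are illustrative assumptions introduced only for this example.

```python
import torch


class FakeProjectionLayer:
    """Stand-in for a diffusers image-projection layer (an assumption for this sketch)."""

    def __init__(self, wants_hidden_states: bool):
        # Plain IP-Adapter projections consume pooled CLIP embeds; the "plus" resampler
        # variants consume penultimate hidden states instead.
        self.wants_hidden_states = wants_hidden_states


def encode_image(image, output_hidden_state: bool):
    # A real pipeline would run a CLIP image encoder here; this stub only mimics the
    # fact that the two modes yield differently shaped (positive, negative) pairs.
    dim = 16 if output_hidden_state else 8
    return torch.randn(dim), torch.zeros(dim)


def prepare_ip_adapter_image_embeds_sketch(ip_adapter_image, projection_layers, num_images_per_prompt, do_cfg):
    # A bare image (or image list) is wrapped so there is one entry per loaded adapter.
    if not isinstance(ip_adapter_image, list):
        ip_adapter_image = [ip_adapter_image]
    if len(ip_adapter_image) != len(projection_layers):
        raise ValueError("need one image (or image list) per loaded IP-Adapter")

    image_embeds = []
    for single_image, layer in zip(ip_adapter_image, projection_layers):
        pos, neg = encode_image(single_image, output_hidden_state=layer.wants_hidden_states)
        # Duplicate per generated image, then put negatives ahead of positives for
        # classifier-free guidance, mirroring how the prompt embeddings are concatenated.
        pos = torch.stack([pos] * num_images_per_prompt, dim=0)
        neg = torch.stack([neg] * num_images_per_prompt, dim=0)
        image_embeds.append(torch.cat([neg, pos]) if do_cfg else pos)

    # One tensor per adapter; the attention processors later consume this as a list.
    return image_embeds


if __name__ == "__main__":
    layers = [FakeProjectionLayer(False), FakeProjectionLayer(True)]
    embeds = prepare_ip_adapter_image_embeds_sketch(
        [torch.zeros(3), torch.zeros(3)], layers, num_images_per_prompt=2, do_cfg=True
    )
    print([tuple(e.shape) for e in embeds])  # two entries, negatives stacked ahead of positives
```

This matches the convention exercised by the UNet tests above: `added_cond_kwargs["image_embeds"]` becomes a list with one entry per adapter, shaped `[batch_size, num_image, embed_dim]` for plain IP-Adapter and `[batch_size, num_image, sequence_length, embed_dim]` for IP-Adapter-Plus, while a bare tensor is still accepted for the single-adapter case.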