diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
index 0ef90c6dd9..10f8cbd999 100644
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ b/docs/source/en/using-diffusers/loading_adapters.md
@@ -506,22 +506,11 @@ import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image
-noise_scheduler = DDIMScheduler(
- num_train_timesteps=1000,
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- steps_offset=1
-)
-
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
- scheduler=noise_scheduler,
).to("cuda")
-
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)
@@ -550,6 +539,66 @@ image = pipeline(
+
+You can load multiple IP-Adapter models and use multiple reference images at the same time. In this example, the IP-Adapter-Plus Face model keeps the generated character consistent, while the IP-Adapter-Plus model, conditioned on 10 reference images, gives the output a coherent style.
+
+```python
+import torch
+from diffusers import AutoPipelineForText2Image, DDIMScheduler
+from transformers import CLIPVisionModelWithProjection
+from diffusers.utils import load_image
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "h94/IP-Adapter",
+ subfolder="models/image_encoder",
+ torch_dtype=torch.float16,
+)
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ image_encoder=image_encoder,
+)
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+pipeline.load_ip_adapter(
+ "h94/IP-Adapter",
+ subfolder="sdxl_models",
+ weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"]
+)
+pipeline.set_ip_adapter_scale([0.7, 0.3])
+pipeline.enable_model_cpu_offload()
+
+face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")
+style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
+style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]
+
+generator = torch.Generator(device="cpu").manual_seed(0)
+
+image = pipeline(
+ prompt="wonderwoman",
+ ip_adapter_image=[style_images, face_image],
+ negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+ num_inference_steps=50, num_images_per_prompt=1,
+ generator=generator,
+).images[0]
+```
+
+
+[figure: style input image]
+
+[figure: face input image]
+
+[figure: output image]
+
### LCM-LoRA

You can use IP-Adapter with LCM-LoRA to achieve an "instant fine-tune" with custom images. Note that you need to load the IP-Adapter weights before loading the LCM-LoRA weights.
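A minimal sketch of that load order, assuming the `latent-consistency/lcm-lora-sdv1-5` checkpoint and the SD 1.5 IP-Adapter weights shown earlier (the full example may use different weights):

```python
import torch
from diffusers import StableDiffusionPipeline, LCMScheduler

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# 1. load the IP-Adapter weights first ...
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# 2. ... then the LCM-LoRA weights, and switch to the matching scheduler
pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
```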
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index f2ac58cf93..679c46d57e 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from pathlib import Path
-from typing import Dict, Union
+from typing import Dict, List, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
@@ -45,9 +46,9 @@ class IPAdapterMixin:
@validate_hf_hub_args
def load_ip_adapter(
self,
- pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
- subfolder: str,
- weight_name: str,
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+ subfolder: Union[str, List[str]],
+ weight_name: Union[str, List[str]],
**kwargs,
):
"""
@@ -87,6 +88,26 @@ class IPAdapterMixin:
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
+ # handle the list inputs for multiple IP Adapters
+ if not isinstance(weight_name, list):
+ weight_name = [weight_name]
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+ if len(pretrained_model_name_or_path_or_dict) == 1:
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+ if not isinstance(subfolder, list):
+ subfolder = [subfolder]
+ if len(subfolder) == 1:
+ subfolder = subfolder * len(weight_name)
+
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+ if len(weight_name) != len(subfolder):
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -100,61 +121,68 @@ class IPAdapterMixin:
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
-
- if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- model_file = _get_model_file(
- pretrained_model_name_or_path_or_dict,
- weights_name=weight_name,
- cache_dir=cache_dir,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- )
- if weight_name.endswith(".safetensors"):
- state_dict = {"image_proj": {}, "ip_adapter": {}}
- with safe_open(model_file, framework="pt", device="cpu") as f:
- for key in f.keys():
- if key.startswith("image_proj."):
- state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
- elif key.startswith("ip_adapter."):
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
- else:
- state_dict = torch.load(model_file, map_location="cpu")
- else:
- state_dict = pretrained_model_name_or_path_or_dict
-
- keys = list(state_dict.keys())
- if keys != ["image_proj", "ip_adapter"]:
- raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
-
- # load CLIP image encoder here if it has not been registered to the pipeline yet
- if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ state_dicts = []
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
+ ):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
- subfolder=Path(subfolder, "image_encoder").as_posix(),
- ).to(self.device, dtype=self.dtype)
- self.image_encoder = image_encoder
- self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(model_file, map_location="cpu")
else:
- raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
+ state_dict = pretrained_model_name_or_path_or_dict
- # create feature extractor if it has not been registered to the pipeline yet
- if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
- self.feature_extractor = CLIPImageProcessor()
- self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])
+ keys = list(state_dict.keys())
+ if keys != ["image_proj", "ip_adapter"]:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
- # load ip-adapter into unet
+ state_dicts.append(state_dict)
+
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
+ if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ pretrained_model_name_or_path_or_dict,
+ subfolder=Path(subfolder, "image_encoder").as_posix(),
+ ).to(self.device, dtype=self.dtype)
+ self.image_encoder = image_encoder
+ self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
+ else:
+ raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
+
+ # create feature extractor if it has not been registered to the pipeline yet
+ if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
+ feature_extractor = CLIPImageProcessor()
+ self.register_modules(feature_extractor=feature_extractor)
+
+ # load ip-adapter into unet
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
- unet._load_ip_adapter_weights(state_dict)
+ unet._load_ip_adapter_weights(state_dicts)
def set_ip_adapter_scale(self, scale):
+ if not isinstance(scale, list):
+ scale = [scale]
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
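The list handling above keeps the single-adapter call unchanged while broadcasting one repo id and subfolder across several weight files. A usage sketch based on the weight names from the docs example:

```python
# single adapter: plain strings behave exactly as before
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# multiple adapters: a single repo id and subfolder are broadcast over the weight names
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"],
)
pipeline.set_ip_adapter_scale([0.7, 0.3])  # one scale per loaded adapter
```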
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index d7c145e430..d359521e91 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -25,7 +25,12 @@ import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
-from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
+from ..models.embeddings import (
+ ImageProjection,
+ IPAdapterFullImageProjection,
+ IPAdapterPlusImageProjection,
+ MultiIPAdapterImageProjection,
+)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
@@ -763,7 +768,7 @@ class UNet2DConditionLoadersMixin:
image_projection.load_state_dict(updated_state_dict)
return image_projection
- def _load_ip_adapter_weights(self, state_dict):
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
@@ -771,20 +776,6 @@ class UNet2DConditionLoadersMixin:
IPAdapterAttnProcessor2_0,
)
- if "proj.weight" in state_dict["image_proj"]:
- # IP-Adapter
- num_image_text_embeds = 4
- elif "proj.3.weight" in state_dict["image_proj"]:
- # IP-Adapter Full Face
- num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
- else:
- # IP-Adapter Plus
- num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
-
- # Set encoder_hid_proj after loading ip_adapter weights,
- # because `IPAdapterPlusImageProjection` also has `attn_processors`.
- self.encoder_hid_proj = None
-
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
@@ -798,6 +789,7 @@ class UNet2DConditionLoadersMixin:
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]
+
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
@@ -807,6 +799,18 @@ class UNet2DConditionLoadersMixin:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
+ num_image_text_embeds = []
+ for state_dict in state_dicts:
+ if "proj.weight" in state_dict["image_proj"]:
+ # IP-Adapter
+ num_image_text_embeds += [4]
+ elif "proj.3.weight" in state_dict["image_proj"]:
+ # IP-Adapter Full Face
+ num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
+ else:
+ # IP-Adapter Plus
+ num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
+
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
@@ -815,16 +819,31 @@ class UNet2DConditionLoadersMixin:
).to(dtype=self.dtype, device=self.device)
value_dict = {}
- for k, w in attn_procs[name].state_dict().items():
- value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
+ for i, state_dict in enumerate(state_dicts):
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
attn_procs[name].load_state_dict(value_dict)
key_id += 2
+ return attn_procs
+
+ def _load_ip_adapter_weights(self, state_dicts):
+ if not isinstance(state_dicts, list):
+ state_dicts = [state_dicts]
+ # Set encoder_hid_proj after loading ip_adapter weights,
+ # because `IPAdapterPlusImageProjection` also has `attn_processors`.
+ self.encoder_hid_proj = None
+
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
self.set_attn_processor(attn_procs)
# convert IP-Adapter Image Projection layers to diffusers
- image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
+ image_projection_layers = []
+ for state_dict in state_dicts:
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
+ image_projection_layer.to(device=self.device, dtype=self.dtype)
+ image_projection_layers.append(image_projection_layer)
- self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
+ self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index ac9563e186..908946119d 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2087,29 +2087,41 @@ class LoRAAttnAddedKVProcessor(nn.Module):
class IPAdapterAttnProcessor(nn.Module):
r"""
- Attention processor for IP-Adapater.
+ Attention processor for multiple IP-Adapters.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
- num_tokens (`int`, defaults to 4):
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
- scale (`float`, defaults to 1.0):
+ scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
"""
- def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
self.num_tokens = num_tokens
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_k_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
def __call__(
self,
@@ -2120,10 +2132,24 @@ class IPAdapterAttnProcessor(nn.Module):
temb=None,
scale=1.0,
):
- if scale != 1.0:
- logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
residual = hidden_states
+ # separate ip_hidden_states from encoder_hidden_states
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, tuple):
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+ else:
+ deprecation_message = (
+ "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
+ )
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ [encoder_hidden_states[:, end_pos:, :]],
+ )
+
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2148,13 +2174,6 @@ class IPAdapterAttnProcessor(nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- # split hidden states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
-
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -2167,17 +2186,20 @@ class IPAdapterAttnProcessor(nn.Module):
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ ):
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
- hidden_states = hidden_states + self.scale * ip_hidden_states
+ hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@@ -2204,13 +2226,13 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
- num_tokens (`int`, defaults to 4):
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
- scale (`float`, defaults to 1.0):
+ scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
"""
- def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
@@ -2220,11 +2242,23 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
self.num_tokens = num_tokens
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_k_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
def __call__(
self,
@@ -2235,10 +2269,24 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
temb=None,
scale=1.0,
):
- if scale != 1.0:
- logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
residual = hidden_states
+ # separate ip_hidden_states from encoder_hidden_states
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, tuple):
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+ else:
+ deprecation_message = (
+ "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
+ )
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ [encoder_hidden_states[:, end_pos:, :]],
+ )
+
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -2268,13 +2316,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- # split hidden states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
-
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -2296,22 +2337,27 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ ):
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
- hidden_states = hidden_states + self.scale * ip_hidden_states
+ hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
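With this change the IP-Adapter processors expect `encoder_hidden_states` as a `(text_embeds, ip_embeds_list)` tuple instead of one concatenated tensor. A self-contained shape check with toy dimensions, a sketch assuming PyTorch 2.x for `scaled_dot_product_attention`:

```python
import torch
from diffusers.models.attention_processor import Attention, IPAdapterAttnProcessor2_0

# two adapters with different context lengths (e.g. plain IP-Adapter and IP-Adapter Plus)
processor = IPAdapterAttnProcessor2_0(
    hidden_size=64, cross_attention_dim=768, num_tokens=[4, 16], scale=[1.0, 0.5]
)
attn = Attention(query_dim=64, cross_attention_dim=768, heads=4, dim_head=16, processor=processor)

hidden_states = torch.randn(1, 32, 64)                         # latent tokens
text_embeds = torch.randn(1, 77, 768)                          # prompt embeddings
ip_embeds = [torch.randn(1, 4, 768), torch.randn(1, 16, 768)]  # one tensor per adapter

out = processor(attn, hidden_states, encoder_hidden_states=(text_embeds, ip_embeds))
print(out.shape)  # torch.Size([1, 32, 64])
```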
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 293b751cb6..1ef035af10 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
-from ..utils import USE_PEFT_BACKEND
+from ..utils import USE_PEFT_BACKEND, deprecate
from .activations import get_activation
from .attention_processor import Attention
from .lora import LoRACompatibleLinear
@@ -878,3 +878,38 @@ class IPAdapterPlusImageProjection(nn.Module):
latents = self.proj_out(latents)
return self.norm_out(latents)
+
+
+class MultiIPAdapterImageProjection(nn.Module):
+ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
+ super().__init__()
+ self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
+
+ def forward(self, image_embeds: List[torch.FloatTensor]):
+ projected_image_embeds = []
+
+ # currently, we accept `image_embeds` as
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
+ # 2. list of `n` tensors where `n` is the number of IP-Adapters; each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
+ if not isinstance(image_embeds, list):
+ deprecation_message = (
+ "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
+ )
+ deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
+ image_embeds = [image_embeds.unsqueeze(1)]
+
+ if len(image_embeds) != len(self.image_projection_layers):
+ raise ValueError(
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
+ )
+
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
+ image_embed = image_projection_layer(image_embed)
+ image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
+
+ projected_image_embeds.append(image_embed)
+
+ return projected_image_embeds
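A quick shape sketch for `MultiIPAdapterImageProjection`, using `nn.Linear` stand-ins for the real per-adapter projection modules:

```python
import torch
from torch import nn
from diffusers.models.embeddings import MultiIPAdapterImageProjection

# stand-in projections; any module mapping [N, 768] -> [N, 1024] works for this check
proj = MultiIPAdapterImageProjection([nn.Linear(768, 1024), nn.Linear(768, 1024)])

# one tensor per adapter: [batch_size, num_images, embed_dim]
image_embeds = [torch.randn(2, 1, 768), torch.randn(2, 10, 768)]
out = proj(image_embeds)
print([tuple(o.shape) for o in out])  # [(2, 1, 1024), (2, 10, 1024)]
```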
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 87297b5b5d..f45ef239be 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1074,8 +1074,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
- image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
- encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+ image_embeds = self.encoder_hid_proj(image_embeds)
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
# 2. pre-process
sample = self.conv_in(sample)
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index ee1062ee81..5988e7657e 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -426,6 +426,35 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
@@ -1002,12 +1031,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_videos_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 6cd1658c59..f0c39952fd 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -506,6 +506,35 @@ class StableDiffusionControlNetPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -1083,12 +1112,9 @@ class StableDiffusionControlNetPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 6e00134591..f5e4775900 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -499,6 +499,35 @@ class StableDiffusionControlNetImg2ImgPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -1087,12 +1116,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 9f3009cede..bc6133c8b2 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -624,6 +624,35 @@ class StableDiffusionControlNetInpaintPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -1335,12 +1364,9 @@ class StableDiffusionControlNetInpaintPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 78793c2866..5165d193dc 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -515,6 +515,35 @@ class StableDiffusionXLControlNetPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1182,12 +1211,9 @@ class StableDiffusionXLControlNetPipeline(
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 12ff9bbbfb..dda2f207b9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -564,6 +564,35 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1340,12 +1369,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare image and controlnet_conditioning_image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index e772d8be2a..60707cc1e2 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1280,8 +1280,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
- image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
- encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+ image_embeds = self.encoder_hid_proj(image_embeds)
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
# 2. pre-process
sample = self.conv_in(sample)
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index 63a54f5aa6..4146a35fb9 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -477,6 +477,35 @@ class LatentConsistencyModelImg2ImgPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -790,9 +819,8 @@ class LatentConsistencyModelImg2ImgPipeline(
# do_classifier_free_guidance = guidance_scale > 1.0
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
# 3. Encode input prompt
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index dc4ad60ce0..46b834f9ce 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -514,6 +514,34 @@ class StableDiffusionPipeline(
return image_embeds, uncond_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -949,12 +977,9 @@ class StableDiffusionPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
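The batching contract of `prepare_ip_adapter_image_embeds`, reproduced in isolation with dummy tensors (assumed shapes; `encode_image` is called with `num_images_per_prompt=1` and the replication happens here instead):

```python
import torch

num_images_per_prompt = 2
do_classifier_free_guidance = True

# stand-ins for what encode_image(..., 1, output_hidden_state) returns per adapter
single_image_embeds = torch.randn(1, 4, 768)
single_negative_image_embeds = torch.zeros_like(single_image_embeds)

# replicate along a new leading dim, one copy per generated image
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)

if do_classifier_free_guidance:
    # negative embeds come first, matching the ordering of the prompt embeddings
    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

print(single_image_embeds.shape)  # torch.Size([4, 1, 4, 768])
```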
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 45dbd1128d..f78cd383b8 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -528,6 +528,35 @@ class StableDiffusionImg2ImgPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -1001,12 +1030,9 @@ class StableDiffusionImg2ImgPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Preprocess image
image = self.image_processor.preprocess(image)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 6751490abd..5d77341511 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -600,6 +600,35 @@ class StableDiffusionInpaintPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -1209,12 +1238,9 @@ class StableDiffusionInpaintPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. set timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
index 699bd10041..5b3aae13f4 100644
--- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
@@ -438,6 +438,35 @@ class StableDiffusionLDM3DPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -654,12 +683,9 @@ class StableDiffusionLDM3DPipeline(
do_classifier_free_guidance = guidance_scale > 1.0
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index f0ef4b9f88..edf93839de 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -396,6 +396,35 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -669,12 +698,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
do_classifier_free_guidance = guidance_scale > 1.0
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 3. Encode input prompt
text_encoder_lora_scale = (
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index f9bafc9733..79f0aa379a 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -549,6 +549,35 @@ class StableDiffusionXLPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1163,13 +1192,9 @@ class StableDiffusionXLPipeline(
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
- image_embeds = image_embeds.to(device)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 1c22affba1..76416a6d33 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -766,6 +766,35 @@ class StableDiffusionXLImg2ImgPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
def _get_add_time_ids(
self,
original_size,
@@ -1337,13 +1366,9 @@ class StableDiffusionXLImg2ImgPipeline(
add_time_ids = add_time_ids.to(device)
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
- image_embeds = image_embeds.to(device)
# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index f9468adba9..248c990b2c 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -487,6 +487,35 @@ class StableDiffusionXLInpaintPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
self,
@@ -1685,13 +1714,9 @@ class StableDiffusionXLInpaintPipeline(
add_time_ids = add_time_ids.to(device)
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
- image_embeds = image_embeds.to(device)
# 11. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 5ec35ddf07..1e97ce4da4 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -563,6 +563,35 @@ class StableDiffusionXLAdapterPipeline(
return image_embeds, uncond_image_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1068,12 +1097,9 @@ class StableDiffusionXLAdapterPipeline(
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
- output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
- image_embeds, negative_image_embeds = self.encode_image(
- ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, device, batch_size * num_images_per_prompt
)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
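
The UNet test updates that follow exercise the new conditioning contract: `added_cond_kwargs["image_embeds"]` is now a list with one tensor per loaded adapter, while the pre-existing single-tensor forms are still accepted when only one adapter is loaded. A minimal shape sketch, with the dimensions chosen arbitrarily for illustration:

```python
import torch

batch_size, cross_attention_dim = 2, 32

# one tensor per adapter: [batch_size, num_images, embed_dim] for plain IP-Adapter,
# [batch_size, num_images, sequence_length, embed_dim] for IP-Adapter Plus
plain_embeds = torch.randn(batch_size, 1, cross_attention_dim)
plus_embeds = torch.randn(batch_size, 1, 4, cross_attention_dim)

multi_adapter_kwargs = {"image_embeds": [plain_embeds, plus_embeds]}

# the pre-existing single-adapter forms (a bare 2-d or 3-d tensor) remain valid
legacy_kwargs = {"image_embeds": plain_embeds.squeeze(1)}
```
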
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 0e2a4765d6..1e8c4dca61 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -25,7 +25,11 @@ from parameterized import parameterized
from pytest import mark
from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
+from diffusers.models.attention_processor import (
+ CustomDiffusionAttnProcessor,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+)
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
@@ -73,8 +77,8 @@ def create_ip_adapter_state_dict(model):
).state_dict()
ip_cross_attn_state_dict.update(
{
- f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
- f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
+ f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+ f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
}
)
@@ -124,8 +128,8 @@ def create_ip_adapter_plus_state_dict(model):
).state_dict()
ip_cross_attn_state_dict.update(
{
- f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
- f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
+ f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+ f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
}
)
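
The `.0.` that appears in the `to_k_ip`/`to_v_ip` keys above reflects that the IP-Adapter attention processors now hold one key/value projection per adapter in an `nn.ModuleList` rather than a single `nn.Linear`. A minimal sketch of that assumed layout (the class name and sizes here are illustrative, not the library's):

```python
import torch.nn as nn

# assumed processor layout after the multi-adapter change: one k/v projection
# per loaded IP-Adapter, hence state-dict keys like "to_k_ip.0.weight"
class IPAdapterAttnSketch(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim, num_adapters=1):
        super().__init__()
        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(num_adapters)]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(num_adapters)]
        )

print(sorted(IPAdapterAttnSketch(8, 4).state_dict()))  # ['to_k_ip.0.weight', 'to_v_ip.0.weight']
```
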
@@ -773,8 +777,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
# update inputs_dict for ip-adapter
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
+ # for IP-Adapter, image_embeds has shape [batch_size, num_images, embed_dim]
image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
- inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+ inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
# make ip_adapter_1 and ip_adapter_2
ip_adapter_1 = create_ip_adapter_state_dict(model)
@@ -785,7 +790,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
# forward pass ip_adapter_1
- model._load_ip_adapter_weights(ip_adapter_1)
+ model._load_ip_adapter_weights([ip_adapter_1])
assert model.config.encoder_hid_dim_type == "ip_image_proj"
assert model.encoder_hid_proj is not None
assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
@@ -796,18 +801,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
sample2 = model(**inputs_dict).sample
# forward pass with ip_adapter_2
- model._load_ip_adapter_weights(ip_adapter_2)
+ model._load_ip_adapter_weights([ip_adapter_2])
with torch.no_grad():
sample3 = model(**inputs_dict).sample
# forward pass with ip_adapter_1 again
- model._load_ip_adapter_weights(ip_adapter_1)
+ model._load_ip_adapter_weights([ip_adapter_1])
with torch.no_grad():
sample4 = model(**inputs_dict).sample
+ # forward pass with multiple ip-adapters and multiple images
+ model._load_ip_adapter_weights([ip_adapter_1, ip_adapter_2])
+ # set the scale for ip_adapter_2 to 0 so that the result matches loading only ip_adapter_1
+ for attn_processor in model.attn_processors.values():
+ if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+ attn_processor.scale = [1, 0]
+ image_embeds_multi = image_embeds.repeat(1, 2, 1)
+ inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds_multi, image_embeds_multi]}
+ with torch.no_grad():
+ sample5 = model(**inputs_dict).sample
+
+ # forward pass with a single ip-adapter and a single image, where image_embeds is a plain 2-d tensor rather than a list
+ image_embeds = image_embeds.squeeze(1)
+ inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+
+ model._load_ip_adapter_weights(ip_adapter_1)
+ with torch.no_grad():
+ sample6 = model(**inputs_dict).sample
+
assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
+ assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
+ assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
def test_ip_adapter_plus(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -823,8 +849,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
# update inputs_dict for ip-adapter
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
- image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
- inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+ # for IP-Adapter Plus, image_embeds has shape [batch_size, num_images, sequence_length, embed_dim]
+ image_embeds = floats_tensor((batch_size, 1, 1, model.cross_attention_dim)).to(torch_device)
+ inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
# make ip_adapter_1 and ip_adapter_2
ip_adapter_1 = create_ip_adapter_plus_state_dict(model)
@@ -835,7 +862,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
# forward pass ip_adapter_1
- model._load_ip_adapter_weights(ip_adapter_1)
+ model._load_ip_adapter_weights([ip_adapter_1])
assert model.config.encoder_hid_dim_type == "ip_image_proj"
assert model.encoder_hid_proj is not None
assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
@@ -846,18 +873,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
sample2 = model(**inputs_dict).sample
# forward pass with ip_adapter_2
- model._load_ip_adapter_weights(ip_adapter_2)
+ model._load_ip_adapter_weights([ip_adapter_2])
with torch.no_grad():
sample3 = model(**inputs_dict).sample
# forward pass with ip_adapter_1 again
- model._load_ip_adapter_weights(ip_adapter_1)
+ model._load_ip_adapter_weights([ip_adapter_1])
with torch.no_grad():
sample4 = model(**inputs_dict).sample
+ # forward pass with multiple ip-adapters and multiple images
+ model._load_ip_adapter_weights([ip_adapter_1, ip_adapter_2])
+ # set the scale for ip_adapter_2 to 0 so that the result matches loading only ip_adapter_1
+ for attn_processor in model.attn_processors.values():
+ if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+ attn_processor.scale = [1, 0]
+ image_embeds_multi = image_embeds.repeat(1, 2, 1, 1)
+ inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds_multi, image_embeds_multi]}
+ with torch.no_grad():
+ sample5 = model(**inputs_dict).sample
+
+ # forward pass with a single ip-adapter and a single image, where image_embeds is a plain 3-d tensor rather than a list
+ image_embeds = image_embeds.squeeze(1)
+ inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
+
+ model._load_ip_adapter_weights(ip_adapter_1)
+ with torch.no_grad():
+ sample6 = model(**inputs_dict).sample
+
assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
+ assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
+ assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@slow
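
The `test_multi` integration test added in the next file pins down the end-user flow. Restated as a standalone sketch, with the prompt, seed, and reference image as placeholders rather than values from the test suite:

```python
import torch
from transformers import CLIPVisionModelWithProjection
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

# standalone sketch of the multi-adapter flow exercised by test_multi below
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name=["ip-adapter_sd15.bin", "ip-adapter-plus_sd15.bin"],
)
pipeline.set_ip_adapter_scale([0.7, 0.3])

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")
result = pipeline(
    prompt="best quality, high quality",
    ip_adapter_image=[image, [image, image]],  # one entry per adapter; entries may be lists
    num_inference_steps=30,
    generator=torch.Generator(device="cpu").manual_seed(0),
).images[0]
```
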
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index 4f0342e4fb..710dea3c2d 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -258,6 +258,27 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
]
assert processors == [True] * len(processors)
+ def test_multi(self):
+ image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
+ )
+ pipeline.to(torch_device)
+ pipeline.load_ip_adapter(
+ "h94/IP-Adapter", subfolder="models", weight_name=["ip-adapter_sd15.bin", "ip-adapter-plus_sd15.bin"]
+ )
+ pipeline.set_ip_adapter_scale([0.7, 0.3])
+
+ inputs = self.get_dummy_inputs()
+ ip_adapter_image = inputs["ip_adapter_image"]
+ inputs["ip_adapter_image"] = [ip_adapter_image, [ip_adapter_image] * 2]
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+ expected_slice = np.array(
+ [0.5234375, 0.53515625, 0.5629883, 0.57128906, 0.59521484, 0.62109375, 0.57910156, 0.6201172, 0.6508789]
+ )
+ assert np.allclose(image_slice, expected_slice, atol=1e-3)
+
@slow
@require_torch_gpu