mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[IP-Adapter] Support multiple IP-Adapters (#6573)
--------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Alvaro Somoza <somoza.alvaro@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -506,22 +506,11 @@ import torch
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
||||
from diffusers.utils import load_image
|
||||
|
||||
noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
steps_offset=1
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
scheduler=noise_scheduler,
|
||||
).to("cuda")
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-full-face_sd15.bin")
|
||||
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
@@ -550,6 +539,66 @@ image = pipeline(
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
You can load multiple IP-Adapter models and use multiple reference images at the same time. In this example we use IP-Adapter-Plus face model to create a consistent character and also use IP-Adapter-Plus model along with 10 images to create a coherent style in the image we generate.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image, DDIMScheduler
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
from diffusers.utils import load_image
|
||||
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"h94/IP-Adapter",
|
||||
subfolder="models/image_encoder",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline.load_ip_adapter(
|
||||
"h94/IP-Adapter",
|
||||
subfolder="sdxl_models",
|
||||
weight_name=["ip-adapter-plus_sdxl_vit-h.safetensors", "ip-adapter-plus-face_sdxl_vit-h.safetensors"]
|
||||
)
|
||||
pipeline.set_ip_adapter_scale([0.7, 0.3])
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")
|
||||
style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
|
||||
style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
image = pipeline(
|
||||
prompt="wonderwoman",
|
||||
ip_adapter_image=[style_images, face_image],
|
||||
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=50, num_images_per_prompt=1,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_style_grid.png" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">style input image</figcaption>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">face input image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_multi_out.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">output image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
### LCM-Lora
|
||||
|
||||
You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights.
|
||||
|
||||
@@ -11,8 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
@@ -45,9 +46,9 @@ class IPAdapterMixin:
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
subfolder: str,
|
||||
weight_name: str,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -87,6 +88,26 @@ class IPAdapterMixin:
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -100,61 +121,68 @@ class IPAdapterMixin:
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=Path(subfolder, "image_encoder").as_posix(),
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.image_encoder = image_encoder
|
||||
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
else:
|
||||
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
self.feature_extractor = CLIPImageProcessor()
|
||||
self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"])
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# load ip-adapter into unet
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=Path(subfolder, "image_encoder").as_posix(),
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.image_encoder = image_encoder
|
||||
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
|
||||
else:
|
||||
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
feature_extractor = CLIPImageProcessor()
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into unet
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet._load_ip_adapter_weights(state_dict)
|
||||
unet._load_ip_adapter_weights(state_dicts)
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale]
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for attn_processor in unet.attn_processors.values():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
|
||||
@@ -25,7 +25,12 @@ import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
IPAdapterFullImageProjection,
|
||||
IPAdapterPlusImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -763,7 +768,7 @@ class UNet2DConditionLoadersMixin:
|
||||
image_projection.load_state_dict(updated_state_dict)
|
||||
return image_projection
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict):
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
@@ -771,20 +776,6 @@ class UNet2DConditionLoadersMixin:
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
elif "proj.3.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter Full Face
|
||||
num_image_text_embeds = 257 # 256 CLIP tokens + 1 CLS token
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
|
||||
|
||||
# Set encoder_hid_proj after loading ip_adapter weights,
|
||||
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 1
|
||||
@@ -798,6 +789,7 @@ class UNet2DConditionLoadersMixin:
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = self.config.block_out_channels[block_id]
|
||||
|
||||
if cross_attention_dim is None or "motion_modules" in name:
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
|
||||
@@ -807,6 +799,18 @@ class UNet2DConditionLoadersMixin:
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
|
||||
)
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds += [4]
|
||||
elif "proj.3.weight" in state_dict["image_proj"]:
|
||||
# IP-Adapter Full Face
|
||||
num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
|
||||
else:
|
||||
# IP-Adapter Plus
|
||||
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
|
||||
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -815,16 +819,31 @@ class UNet2DConditionLoadersMixin:
|
||||
).to(dtype=self.dtype, device=self.device)
|
||||
|
||||
value_dict = {}
|
||||
for k, w in attn_procs[name].state_dict().items():
|
||||
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
key_id += 2
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
# Set encoder_hid_proj after loading ip_adapter weights,
|
||||
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
# convert IP-Adapter Image Projection layers to diffusers
|
||||
image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
|
||||
image_projection_layers = []
|
||||
for state_dict in state_dicts:
|
||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
|
||||
image_projection_layer.to(device=self.device, dtype=self.dtype)
|
||||
image_projection_layers.append(image_projection_layer)
|
||||
|
||||
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
|
||||
@@ -2087,29 +2087,41 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
|
||||
class IPAdapterAttnProcessor(nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapater.
|
||||
Attention processor for Multiple IP-Adapater.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, defaults to 4):
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float`, defaults to 1.0):
|
||||
scale (`float` or List[`float`], defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -2120,10 +2132,24 @@ class IPAdapterAttnProcessor(nn.Module):
|
||||
temb=None,
|
||||
scale=1.0,
|
||||
):
|
||||
if scale != 1.0:
|
||||
logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -2148,13 +2174,6 @@ class IPAdapterAttnProcessor(nn.Module):
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# split hidden states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
encoder_hidden_states[:, end_pos:, :],
|
||||
)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
@@ -2167,17 +2186,20 @@ class IPAdapterAttnProcessor(nn.Module):
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
||||
):
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
hidden_states = hidden_states + scale * current_ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
@@ -2204,13 +2226,13 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, defaults to 4):
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float`, defaults to 1.0):
|
||||
scale (`float` or `List[float]`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@@ -2220,11 +2242,23 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -2235,10 +2269,24 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
temb=None,
|
||||
scale=1.0,
|
||||
):
|
||||
if scale != 1.0:
|
||||
logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -2268,13 +2316,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# split hidden states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
encoder_hidden_states[:, end_pos:, :],
|
||||
)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
@@ -2296,22 +2337,27 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
||||
):
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
current_ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
hidden_states = hidden_states + scale * current_ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
|
||||
@@ -12,13 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from ..utils import USE_PEFT_BACKEND, deprecate
|
||||
from .activations import get_activation
|
||||
from .attention_processor import Attention
|
||||
from .lora import LoRACompatibleLinear
|
||||
@@ -878,3 +878,38 @@ class IPAdapterPlusImageProjection(nn.Module):
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
|
||||
class MultiIPAdapterImageProjection(nn.Module):
|
||||
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
|
||||
super().__init__()
|
||||
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
|
||||
|
||||
def forward(self, image_embeds: List[torch.FloatTensor]):
|
||||
projected_image_embeds = []
|
||||
|
||||
# currently, we accept `image_embeds` as
|
||||
# 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
|
||||
# 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
|
||||
if not isinstance(image_embeds, list):
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
|
||||
)
|
||||
deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
|
||||
image_embeds = [image_embeds.unsqueeze(1)]
|
||||
|
||||
if len(image_embeds) != len(self.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
|
||||
)
|
||||
|
||||
for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
|
||||
batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
|
||||
image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
|
||||
image_embed = image_projection_layer(image_embed)
|
||||
image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
|
||||
|
||||
projected_image_embeds.append(image_embed)
|
||||
|
||||
return projected_image_embeds
|
||||
|
||||
@@ -1074,8 +1074,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
||||
)
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
|
||||
image_embeds = self.encoder_hid_proj(image_embeds)
|
||||
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
@@ -426,6 +426,35 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
@@ -1002,12 +1031,9 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_videos_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -506,6 +506,35 @@ class StableDiffusionControlNetPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -1083,12 +1112,9 @@ class StableDiffusionControlNetPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
|
||||
@@ -499,6 +499,35 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -1087,12 +1116,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
|
||||
@@ -624,6 +624,35 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -1335,12 +1364,9 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
|
||||
@@ -515,6 +515,35 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
@@ -1182,12 +1211,9 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
# 3.2 Encode ip_adapter_image
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
|
||||
@@ -564,6 +564,35 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
@@ -1340,12 +1369,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
|
||||
# 3.2 Encode ip_adapter_image
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare image and controlnet_conditioning_image
|
||||
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
|
||||
@@ -1280,8 +1280,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
||||
)
|
||||
image_embeds = added_cond_kwargs.get("image_embeds")
|
||||
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
|
||||
image_embeds = self.encoder_hid_proj(image_embeds)
|
||||
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
@@ -477,6 +477,35 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -790,9 +819,8 @@ class LatentConsistencyModelImg2ImgPipeline(
|
||||
# do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
|
||||
# 3. Encode input prompt
|
||||
|
||||
@@ -514,6 +514,34 @@ class StableDiffusionPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -949,12 +977,9 @@ class StableDiffusionPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
@@ -528,6 +528,35 @@ class StableDiffusionImg2ImgPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -1001,12 +1030,9 @@ class StableDiffusionImg2ImgPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
@@ -600,6 +600,35 @@ class StableDiffusionInpaintPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -1209,12 +1238,9 @@ class StableDiffusionInpaintPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. set timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
@@ -438,6 +438,35 @@ class StableDiffusionLDM3DPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
@@ -654,12 +683,9 @@ class StableDiffusionLDM3DPipeline(
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
|
||||
@@ -396,6 +396,35 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -669,12 +698,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
|
||||
@@ -549,6 +549,35 @@ class StableDiffusionXLPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
@@ -1163,13 +1192,9 @@ class StableDiffusionXLPipeline(
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
image_embeds = image_embeds.to(device)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
@@ -766,6 +766,35 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
def _get_add_time_ids(
|
||||
self,
|
||||
original_size,
|
||||
@@ -1337,13 +1366,9 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
image_embeds = image_embeds.to(device)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
@@ -487,6 +487,35 @@ class StableDiffusionXLInpaintPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
@@ -1685,13 +1714,9 @@ class StableDiffusionXLInpaintPipeline(
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
image_embeds = image_embeds.to(device)
|
||||
|
||||
# 11. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
@@ -563,6 +563,35 @@ class StableDiffusionXLAdapterPipeline(
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
image_embeds = []
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
@@ -1068,12 +1097,9 @@ class StableDiffusionXLAdapterPipeline(
|
||||
|
||||
# 3.2 Encode ip_adapter_image
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image, device, batch_size * num_images_per_prompt
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
@@ -25,7 +25,11 @@ from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
|
||||
from diffusers.models.attention_processor import (
|
||||
CustomDiffusionAttnProcessor,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
@@ -73,8 +77,8 @@ def create_ip_adapter_state_dict(model):
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -124,8 +128,8 @@ def create_ip_adapter_plus_state_dict(model):
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
{
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
|
||||
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
|
||||
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -773,8 +777,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
|
||||
# update inputs_dict for ip-adapter
|
||||
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
|
||||
# for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim]
|
||||
image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
|
||||
|
||||
# make ip_adapter_1 and ip_adapter_2
|
||||
ip_adapter_1 = create_ip_adapter_state_dict(model)
|
||||
@@ -785,7 +790,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
|
||||
|
||||
# forward pass ip_adapter_1
|
||||
model._load_ip_adapter_weights(ip_adapter_1)
|
||||
model._load_ip_adapter_weights([ip_adapter_1])
|
||||
assert model.config.encoder_hid_dim_type == "ip_image_proj"
|
||||
assert model.encoder_hid_proj is not None
|
||||
assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
|
||||
@@ -796,18 +801,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
sample2 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with ip_adapter_2
|
||||
model._load_ip_adapter_weights(ip_adapter_2)
|
||||
model._load_ip_adapter_weights([ip_adapter_2])
|
||||
with torch.no_grad():
|
||||
sample3 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with ip_adapter_1 again
|
||||
model._load_ip_adapter_weights(ip_adapter_1)
|
||||
model._load_ip_adapter_weights([ip_adapter_1])
|
||||
with torch.no_grad():
|
||||
sample4 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with multiple ip-adapters and multiple images
|
||||
model._load_ip_adapter_weights([ip_adapter_1, ip_adapter_2])
|
||||
# set the scale for ip_adapter_2 to 0 so that result should be same as only load ip_adapter_1
|
||||
for attn_processor in model.attn_processors.values():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
attn_processor.scale = [1, 0]
|
||||
image_embeds_multi = image_embeds.repeat(1, 2, 1)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds_multi, image_embeds_multi]}
|
||||
with torch.no_grad():
|
||||
sample5 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with single ip-adapter & single image when image_embeds is not a list and a 2-d tensor
|
||||
image_embeds = image_embeds.squeeze(1)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
|
||||
|
||||
model._load_ip_adapter_weights(ip_adapter_1)
|
||||
with torch.no_grad():
|
||||
sample6 = model(**inputs_dict).sample
|
||||
|
||||
assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
|
||||
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_ip_adapter_plus(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
@@ -823,8 +849,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
|
||||
# update inputs_dict for ip-adapter
|
||||
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
|
||||
image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
|
||||
# for ip-adapter-plus image_embeds has shape [batch_size, num_image, sequence_length, embed_dim]
|
||||
image_embeds = floats_tensor((batch_size, 1, 1, model.cross_attention_dim)).to(torch_device)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
|
||||
|
||||
# make ip_adapter_1 and ip_adapter_2
|
||||
ip_adapter_1 = create_ip_adapter_plus_state_dict(model)
|
||||
@@ -835,7 +862,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
|
||||
|
||||
# forward pass ip_adapter_1
|
||||
model._load_ip_adapter_weights(ip_adapter_1)
|
||||
model._load_ip_adapter_weights([ip_adapter_1])
|
||||
assert model.config.encoder_hid_dim_type == "ip_image_proj"
|
||||
assert model.encoder_hid_proj is not None
|
||||
assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
|
||||
@@ -846,18 +873,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
sample2 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with ip_adapter_2
|
||||
model._load_ip_adapter_weights(ip_adapter_2)
|
||||
model._load_ip_adapter_weights([ip_adapter_2])
|
||||
with torch.no_grad():
|
||||
sample3 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with ip_adapter_1 again
|
||||
model._load_ip_adapter_weights(ip_adapter_1)
|
||||
model._load_ip_adapter_weights([ip_adapter_1])
|
||||
with torch.no_grad():
|
||||
sample4 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with multiple ip-adapters and multiple images
|
||||
model._load_ip_adapter_weights([ip_adapter_1, ip_adapter_2])
|
||||
# set the scale for ip_adapter_2 to 0 so that result should be same as only load ip_adapter_1
|
||||
for attn_processor in model.attn_processors.values():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
attn_processor.scale = [1, 0]
|
||||
image_embeds_multi = image_embeds.repeat(1, 2, 1, 1)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds_multi, image_embeds_multi]}
|
||||
with torch.no_grad():
|
||||
sample5 = model(**inputs_dict).sample
|
||||
|
||||
# forward pass with single ip-adapter & single image when image_embeds is a 3-d tensor
|
||||
image_embeds = image_embeds[:,].squeeze(1)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
|
||||
|
||||
model._load_ip_adapter_weights(ip_adapter_1)
|
||||
with torch.no_grad():
|
||||
sample6 = model(**inputs_dict).sample
|
||||
|
||||
assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
|
||||
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -258,6 +258,27 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
]
|
||||
assert processors == [True] * len(processors)
|
||||
|
||||
def test_multi(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
|
||||
)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.load_ip_adapter(
|
||||
"h94/IP-Adapter", subfolder="models", weight_name=["ip-adapter_sd15.bin", "ip-adapter-plus_sd15.bin"]
|
||||
)
|
||||
pipeline.set_ip_adapter_scale([0.7, 0.3])
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
ip_adapter_image = inputs["ip_adapter_image"]
|
||||
inputs["ip_adapter_image"] = [ip_adapter_image, [ip_adapter_image] * 2]
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
expected_slice = np.array(
|
||||
[0.5234375, 0.53515625, 0.5629883, 0.57128906, 0.59521484, 0.62109375, 0.57910156, 0.6201172, 0.6508789]
|
||||
)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user