1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Move IP Adapter Face ID to core (#7186)

* Switch to peft and multi proj layers

* Move Face ID loading and inference to core

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Fabio Rigano
2024-04-19 02:13:27 +02:00
committed by GitHub
parent e23c27e905
commit b5c8b555d7
10 changed files with 595 additions and 378 deletions

View File

@@ -362,14 +362,12 @@ IP-Adapter's image prompting and compatibility with other adapters and models ma
### Face model
Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces:
Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository:
* [ip-adapter-full-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-full-face_sd15.safetensors) is conditioned with images of cropped faces and removed backgrounds
* [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces
> [!TIP]
>
> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters.
Additionally, Diffusers supports all IP-Adapter checkpoints trained with face embeddings extracted by `insightface` face models. Supported models are from the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository.
For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.
@@ -411,6 +409,56 @@ image
</div>
</div>
To use IP-Adapter FaceID models, first extract face embeddings with `insightface`. Then pass the list of tensors to the pipeline as `ip_adapter_image_embeds`.
```py
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image
from insightface.app import FaceAnalysis
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
).to("cuda")
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
pipeline.set_ip_adapter_scale(0.6)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")
ref_images_embeds = []
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
faces = app.get(image)
image = torch.from_numpy(faces[0].normed_embedding)
ref_images_embeds.append(image.unsqueeze(0))
ref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)
neg_ref_images_embeds = torch.zeros_like(ref_images_embeds)
id_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device="cuda"))
generator = torch.Generator(device="cpu").manual_seed(42)
images = pipeline(
prompt="A photo of a girl",
ip_adapter_image_embeds=[id_embeds],
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=20, num_images_per_prompt=1,
generator=generator
).images
```
Both IP-Adapter FaceID Plus and Plus v2 models require CLIP image embeddings. You can prepare face embeddings as shown previously, then you can extract and pass CLIP embeddings to the hidden image projection layers.
```py
clip_embeds = pipeline.prepare_ip_adapter_image_embeds([ip_adapter_images], None, torch.device("cuda"), num_images, True)[0]
pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False # True if Plus v2
```
### Multi IP-Adapter
More than one IP-Adapter can be used at the same time to generate specific images in more diverse styles. For example, you can use IP-Adapter-Face to generate consistent faces and characters, and IP-Adapter Plus to generate those faces in a specific style.

View File

@@ -320,3 +320,40 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
```
### IP-Adapter Face ID models
The IP-Adapter FaceID models are experimental IP Adapters that use image embeddings generated by `insightface` instead of CLIP image embeddings. Some of these models also use LoRA to improve ID consistency.
You need to install `insightface` and all its requirements to use these models.
<Tip warning={true}>
As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and are not intended for commercial use.
</Tip>
```py
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin", image_encoder_folder=None)
```
If you want to use one of the two IP-Adapter FaceID Plus models, you must also load the CLIP image encoder, as this models use both `insightface` and CLIP image embeddings to achieve better photorealism.
```py
from transformers import CLIPVisionModelWithProjection
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
torch_dtype=torch.float16,
)
pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5",
image_encoder=image_encoder,
torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plus_sd15.bin")
```

View File

@@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif")
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
You have to disable PEFT BACKEND in order to load weights.
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).
```py
import diffusers
diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2

View File

@@ -26,7 +26,14 @@ from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer, adjust_lora_scale_text_encoder
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
from diffusers.models.embeddings import MultiIPAdapterImageProjection
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -45,300 +52,6 @@ from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LoRAIPAdapterAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapater.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
lora_scale (`float`, defaults to 1.0):
the weight scale of LoRA.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(
self,
hidden_size,
cross_attention_dim=None,
rank=4,
network_alpha=None,
lora_scale=1.0,
scale=1.0,
num_tokens=4,
):
super().__init__()
self.rank = rank
self.lora_scale = lora_scale
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class LoRAIPAdapterAttnProcessor2_0(nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
lora_scale (`float`, defaults to 1.0):
the weight scale of LoRA.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(
self,
hidden_size,
cross_attention_dim=None,
rank=4,
network_alpha=None,
lora_scale=1.0,
scale=1.0,
num_tokens=4,
):
super().__init__()
self.rank = rank
self.lora_scale = lora_scale
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAdapterFullImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
super().__init__()
@@ -615,17 +328,13 @@ class IPAdapterFaceIDStableDiffusionPipeline(
return image_projection
def _load_ip_adapter_weights(self, state_dict):
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
)
num_image_text_embeds = 4
self.unet.encoder_hid_proj = None
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
lora_dict = {}
key_id = 0
for name in self.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
@@ -642,94 +351,99 @@ class IPAdapterFaceIDStableDiffusionPipeline(
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0]
attn_module = self.unet
for n in name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features,
out_features=attn_module.to_q.out_features,
rank=rank,
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features,
out_features=attn_module.to_k.out_features,
rank=rank,
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features,
out_features=attn_module.to_v.out_features,
rank=rank,
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=rank,
)
)
value_dict = {}
for k, module in attn_module.named_children():
index = "."
if not hasattr(module, "set_lora_layer"):
index = ".0."
module = module[0]
lora_layer = getattr(module, "lora_layer")
for lora_name, w in lora_layer.state_dict().items():
value_dict.update(
{
f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][
f"{key_id}.{k}_lora.{lora_name}"
]
}
)
attn_module.load_state_dict(value_dict, strict=False)
attn_module.to(dtype=self.dtype, device=self.device)
lora_dict.update(
{f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
)
lora_dict.update(
{
f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
f"{key_id}.to_out_lora.down.weight"
]
}
)
lora_dict.update(
{f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
)
key_id += 1
else:
rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0]
attn_processor_class = (
LoRAIPAdapterAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else LoRAIPAdapterAttnProcessor
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
rank=rank,
num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)
value_dict = {}
for k, w in attn_procs[name].state_dict().items():
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
lora_dict.update(
{f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
)
lora_dict.update(
{
f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
f"{key_id}.to_out_lora.down.weight"
]
}
)
lora_dict.update(
{f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
)
lora_dict.update(
{f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
)
value_dict = {}
value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
attn_procs[name].load_state_dict(value_dict)
key_id += 1
self.unet.set_attn_processor(attn_procs)
self.load_lora_weights(lora_dict, adapter_name="faceid")
self.set_adapters(["faceid"], adapter_weights=[1.0])
# convert IP-Adapter Image Projection layers to diffusers
image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)]
self.unet.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.unet.config.encoder_hid_dim_type = "ip_image_proj"
def set_ip_adapter_scale(self, scale):
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0)):
attn_processor.scale = scale
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = [scale]
def _encode_prompt(
self,
@@ -1298,7 +1012,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
negative_image_embeds = torch.zeros_like(image_embeds)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = [image_embeds]
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -1319,7 +1033,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None
added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {}
# 6.2 Optionally get Guidance Scale Embedding
timestep_cond = None

View File

@@ -21,6 +21,7 @@ from safetensors import safe_open
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
is_accelerate_available,
is_torch_version,
@@ -228,6 +229,18 @@ class IPAdapterMixin:
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
extra_loras = unet._load_ip_adapter_loras(state_dicts)
if extra_loras != {}:
if not USE_PEFT_BACKEND:
logger.warning("PEFT backend is required to load these weights.")
else:
# apply the IP Adapter Face ID LoRA weights
peft_config = getattr(unet, "peft_config", {})
for k, lora in extra_loras.items():
if f"faceid_{k}" not in peft_config:
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
def set_ip_adapter_scale(self, scale):
"""
Sets the conditioning scale between text and image.

View File

@@ -27,6 +27,8 @@ from torch import nn
from ..models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
IPAdapterFaceIDPlusImageProjection,
IPAdapterFullImageProjection,
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
@@ -756,6 +758,90 @@ class UNet2DConditionLoadersMixin:
diffusers_name = diffusers_name.replace("proj.3", "norm")
updated_state_dict[diffusers_name] = value
elif "perceiver_resampler.proj_in.weight" in state_dict:
# IP-Adapter Face ID Plus
id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64
with init_context():
image_projection = IPAdapterFaceIDPlusImageProjection(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
heads=heads,
id_embeddings_dim=id_embeddings_dim,
)
for key, value in state_dict.items():
diffusers_name = key.replace("perceiver_resampler.", "")
diffusers_name = diffusers_name.replace("0.to", "attn.to")
diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")
if "norm1" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
elif "norm2" in diffusers_name:
updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
elif "to_kv" in diffusers_name:
v_chunk = value.chunk(2, dim=0)
updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in diffusers_name:
updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
elif "proj.0.weight" == diffusers_name:
updated_state_dict["proj.net.0.proj.weight"] = value
elif "proj.0.bias" == diffusers_name:
updated_state_dict["proj.net.0.proj.bias"] = value
elif "proj.2.weight" == diffusers_name:
updated_state_dict["proj.net.2.weight"] = value
elif "proj.2.bias" == diffusers_name:
updated_state_dict["proj.net.2.bias"] = value
else:
updated_state_dict[diffusers_name] = value
elif "norm.weight" in state_dict:
# IP-Adapter Face ID
id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
multiplier = id_embeddings_dim_out // id_embeddings_dim_in
norm_layer = "norm.weight"
cross_attention_dim = state_dict[norm_layer].shape[0]
num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
with init_context():
image_projection = IPAdapterFaceIDImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=id_embeddings_dim_in,
mult=multiplier,
num_tokens=num_tokens,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj.0", "ff.net.0.proj")
diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
updated_state_dict[diffusers_name] = value
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["latents"].shape[1]
@@ -847,6 +933,7 @@ class UNet2DConditionLoadersMixin:
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
@@ -859,6 +946,12 @@ class UNet2DConditionLoadersMixin:
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]:
# IP-Adapter Face ID Plus
num_image_text_embeds += [4]
elif "norm.weight" in state_dict["image_proj"]:
# IP-Adapter Face ID
num_image_text_embeds += [4]
else:
# IP-Adapter Plus
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
@@ -910,6 +1003,59 @@ class UNet2DConditionLoadersMixin:
self.to(dtype=self.dtype, device=self.device)
def _load_ip_adapter_loras(self, state_dicts):
lora_dicts = {}
for key_id, name in enumerate(self.attn_processors.keys()):
for i, state_dict in enumerate(state_dicts):
if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]:
if i not in lora_dicts:
lora_dicts[i] = {}
lora_dicts[i].update(
{
f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][
f"{key_id}.to_k_lora.down.weight"
]
}
)
lora_dicts[i].update(
{
f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][
f"{key_id}.to_q_lora.down.weight"
]
}
)
lora_dicts[i].update(
{
f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][
f"{key_id}.to_v_lora.down.weight"
]
}
)
lora_dicts[i].update(
{
f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
f"{key_id}.to_out_lora.down.weight"
]
}
)
lora_dicts[i].update(
{f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
)
lora_dicts[i].update(
{f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
)
lora_dicts[i].update(
{f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
)
lora_dicts[i].update(
{
f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][
f"{key_id}.to_out_lora.up.weight"
]
}
)
return lora_dicts
class FromOriginalUNetMixin:
"""

View File

@@ -472,6 +472,22 @@ class IPAdapterFullImageProjection(nn.Module):
return self.norm(self.ff(image_embeds))
class IPAdapterFaceIDImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
super().__init__()
from .attention import FeedForward
self.num_tokens = num_tokens
self.cross_attention_dim = cross_attention_dim
self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds: torch.FloatTensor):
x = self.ff(image_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
return self.norm(x)
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
@@ -794,13 +810,14 @@ class IPAdapterPlusImageProjection(nn.Module):
"""Resampler of IP-Adapter Plus.
Args:
----
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
that is the same
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
hidden_dims (int):
The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
Defaults to 16. num_queries (int):
The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
of feedforward network hidden
layer channels. Defaults to 4.
"""
@@ -851,11 +868,8 @@ class IPAdapterPlusImageProjection(nn.Module):
"""Forward pass.
Args:
----
x (torch.Tensor): Input Tensor.
Returns:
-------
torch.Tensor: Output Tensor.
"""
latents = self.latents.repeat(x.size(0), 1, 1)
@@ -875,6 +889,119 @@ class IPAdapterPlusImageProjection(nn.Module):
return self.norm_out(latents)
class IPAdapterPlusImageProjectionBlock(nn.Module):
def __init__(
self,
embed_dims: int = 768,
dim_head: int = 64,
heads: int = 16,
ffn_ratio: float = 4,
) -> None:
super().__init__()
from .attention import FeedForward
self.ln0 = nn.LayerNorm(embed_dims)
self.ln1 = nn.LayerNorm(embed_dims)
self.attn = Attention(
query_dim=embed_dims,
dim_head=dim_head,
heads=heads,
out_bias=False,
)
self.ff = nn.Sequential(
nn.LayerNorm(embed_dims),
FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
)
def forward(self, x, latents, residual):
encoder_hidden_states = self.ln0(x)
latents = self.ln1(latents)
encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
latents = self.attn(latents, encoder_hidden_states) + residual
latents = self.ff(latents) + latents
return latents
class IPAdapterFaceIDPlusImageProjection(nn.Module):
"""FacePerceiverResampler of IP-Adapter Plus.
Args:
embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
that is the same
number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
hidden_dims (int):
The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
ffproj_ratio (float): The expansion ratio of feedforward network hidden
layer channels (for ID embeddings). Defaults to 4.
"""
def __init__(
self,
embed_dims: int = 768,
output_dims: int = 768,
hidden_dims: int = 1280,
id_embeddings_dim: int = 512,
depth: int = 4,
dim_head: int = 64,
heads: int = 16,
num_tokens: int = 4,
num_queries: int = 8,
ffn_ratio: float = 4,
ffproj_ratio: int = 2,
) -> None:
super().__init__()
from .attention import FeedForward
self.num_tokens = num_tokens
self.embed_dim = embed_dims
self.clip_embeds = None
self.shortcut = False
self.shortcut_scale = 1.0
self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
self.norm = nn.LayerNorm(embed_dims)
self.proj_in = nn.Linear(hidden_dims, embed_dims)
self.proj_out = nn.Linear(embed_dims, output_dims)
self.norm_out = nn.LayerNorm(output_dims)
self.layers = nn.ModuleList(
[IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
)
def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
id_embeds (torch.Tensor): Input Tensor (ID embeds).
Returns:
torch.Tensor: Output Tensor.
"""
id_embeds = id_embeds.to(self.clip_embeds.dtype)
id_embeds = self.proj(id_embeds)
id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
id_embeds = self.norm(id_embeds)
latents = id_embeds
clip_embeds = self.proj_in(self.clip_embeds)
x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
for block in self.layers:
residual = latents
latents = block(x, latents, residual)
latents = self.proj_out(latents)
out = self.norm_out(latents)
if self.shortcut:
out = id_embeds + self.shortcut_scale * out
return out
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()

View File

@@ -30,7 +30,7 @@ from diffusers.models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
@@ -190,6 +190,64 @@ def create_ip_adapter_plus_state_dict(model):
return ip_state_dict
def create_ip_adapter_faceid_state_dict(model):
# "ip_adapter" (cross-attention weights)
# no LoRA weights
ip_cross_attn_state_dict = {}
key_id = 1
for name in model.attn_processors.keys():
cross_attention_dim = (
None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
if cross_attention_dim is not None:
sd = IPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
}
)
key_id += 2
# "image_proj" (ImageProjection layer weights)
cross_attention_dim = model.config["cross_attention_dim"]
image_projection = IPAdapterFaceIDImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, mult=2, num_tokens=4
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.0.weight": sd["ff.net.0.proj.weight"],
"proj.0.bias": sd["ff.net.0.proj.bias"],
"proj.2.weight": sd["ff.net.2.weight"],
"proj.2.bias": sd["ff.net.2.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
def create_custom_diffusion_layers(model, mock_weights: bool = True):
train_kv = True
train_q_out = True

View File

@@ -37,6 +37,7 @@ from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_flaky,
load_pt,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
@@ -306,6 +307,35 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
def test_text_to_image_face_id(self):
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype
)
pipeline.to(torch_device)
pipeline.load_ip_adapter(
"h94/IP-Adapter-FaceID",
subfolder=None,
weight_name="ip-adapter-faceid_sd15.bin",
image_encoder_folder=None,
)
pipeline.set_ip_adapter_scale(0.7)
inputs = self.get_dummy_inputs()
id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
0
]
id_embeds = id_embeds.reshape((2, 1, 1, 512))
inputs["ip_adapter_image_embeds"] = [id_embeds]
inputs["ip_adapter_image"] = None
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array(
[0.32714844, 0.3239746, 0.3466797, 0.31835938, 0.30004883, 0.3251953, 0.3215332, 0.3552246, 0.3251953]
)
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
@slow
@require_torch_gpu

View File

@@ -48,7 +48,10 @@ from ..models.autoencoders.test_models_vae import (
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
from ..models.unets.test_models_unet_2d_condition import (
create_ip_adapter_faceid_state_dict,
create_ip_adapter_state_dict,
)
from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -239,6 +242,9 @@ class IPAdapterTesterMixin:
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((2, 1, cross_attention_dim), device=torch_device)
def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((2, 1, 1, cross_attention_dim), device=torch_device)
def _get_dummy_masks(self, input_size: int = 64):
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
_masks[0, :, :, : int(input_size / 2)] = 1
@@ -416,6 +422,46 @@ class IPAdapterTesterMixin:
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
)
def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
output_without_adapter = pipe(**inputs)[0]
output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs)[0]
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs)[0]
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
)
class PipelineLatentTesterMixin:
"""